Unverified Commit ce163f9e authored by Ming-Xu Huang, committed by GitHub

[JAX] Support SP + RoPE + GeLU (#602)



* Adding support for sequence parallelism
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding RoPE
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix wrong batch_logical_axes
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Renaming FSDP outer env var
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Porting RoPE to Praxis layers.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Porting GeLU + [FP8 Cast].
Signed-off-by: Ming Huang <mingh@nvidia.com>

* WAR to make XLA successfully match FP8 GEMM on FFN1 with GeLU.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Allowing arbitrary dimensions of NVShape for the workspace allocation
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding checkpoint_name to fused functions of mlp.py to get better perf with nn.scan.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Modify with review feedback.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix bugs
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix typo.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fixes for lint
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Follow review feedback to modify code.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix typo.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Port SP to Praxis
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix an issue when enabling both GQA and RoPE.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Update docs
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
parent 29b0c9ca
......@@ -6,10 +6,23 @@
Jax
=======
.. autoapiclass:: transformer_engine.jax.MajorShardingType
.. autoapiclass:: transformer_engine.jax.ShardingType
Pre-defined Variables of Logical Axes
-------------------------------------
These variables are available in `transformer_engine.jax.sharding`.
* BATCH_AXES: The logical axis of the batch dimension. It is usually sharded along DP + FSDP on the Mesh.
* SEQLEN_AXES: The logical axis of the sequence-length dimension. It is usually not sharded.
* SEQLEN_TP_AXES: The logical axis of the sequence-length dimension. It is usually sharded along TP on the Mesh.
* HEAD_AXES: The logical axis of the head dimension of MHA. It is usually sharded along TP on the Mesh.
* HIDDEN_AXES: The logical axis of the hidden dimension. It is usually not sharded.
* HIDDEN_TP_AXES: The logical axis of the hidden dimension. It is usually sharded along TP on the Mesh.
* JOINED_AXES: The logical axis of otherwise-undefined dimensions. It is usually not sharded.
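As an illustration, these logical axes can be resolved to mesh axes with a simple rule table before building a `jax.sharding.NamedSharding`. This is a minimal sketch, not TE's internal mechanism; the mesh axis names ('dp', 'tp') and the rule mapping are assumptions for the example.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from transformer_engine.jax.sharding import BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES

# Assumed (dp, tp) device mesh; reshape to match your device count.
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ('dp', 'tp'))
# Assumed rules: batch is data-parallel; sequence and hidden stay replicated.
rules = {BATCH_AXES: 'dp', SEQLEN_AXES: None, HIDDEN_AXES: None}

def logical_to_spec(*logical_axes):
    # Translate logical axis names into a PartitionSpec via the rule table.
    return PartitionSpec(*(rules.get(a) for a in logical_axes))

x_sharding = NamedSharding(mesh, logical_to_spec(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))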
Modules
------------------------------------
.. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType
.. autoapiclass:: transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None)
.. autoapiclass:: transformer_engine.jax.MeshResource()
.. autoapifunction:: transformer_engine.jax.fp8_autocast
......
......@@ -35,6 +35,7 @@ INPUT_KEY = 'input_rng'
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
enable_seq_paral: bool
@nn.compact
def __call__(self, x, mask, disable_dropout=False):
......@@ -50,11 +51,17 @@ class Net(nn.Module):
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type='padding',
enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
if self.enable_seq_paral:
# Trigger an all-gather to collect the complete tensor along the sequence dimension on each device.
x = jax.lax.with_sharding_constraint(x,
jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None))
x = te_flax.DenseGeneral(features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
......@@ -266,7 +273,7 @@ def train_and_evaluate(args):
with te.fp8_autocast(args.use_fp8,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None,
None)):
encoder = Net(num_embed)
encoder = Net(num_embed, args.enable_sp)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
......@@ -379,6 +386,10 @@ def encoder_parser(args):
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration")
parser.add_argument("--enable-sp",
action="store_true",
default=False,
help="Enable sequence parallelism.")
return parser.parse_args(args)
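As a quick sanity check, the new flag can be exercised through the parser; a usage sketch, assuming the FP8 flag is spelled --use-fp8 (its definition is only partially shown in this hunk):

args = encoder_parser(["--enable-sp", "--use-fp8"])
assert args.enable_sp and args.use_fp8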
......@@ -405,6 +416,20 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
def test_te_bf16_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8_sp(self):
"""Test Transformer Engine with FP8 + SP"""
self.args.enable_sp = True
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
......@@ -14,6 +14,8 @@ from jax import jit, value_and_grad
from flax import linen as nn
from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import dgelu, dgelu_dbias_cast_transpose
from transformer_engine.jax.cpp_extensions import gelu, gelu_fp8
from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu
from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
......@@ -21,6 +23,7 @@ from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp
from transformer_engine.jax.mlp import layernorm_gelu_fp8_mlp
GEMM_CASES = [
(256, 256, 512),
......@@ -285,6 +288,126 @@ class TestFP8Dot:
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)])
def test_grad_ln_gelu_fp8_mlp(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 6)
activations = ('gelu',)
a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16)
k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
b1 = jax.random.normal(subkeys[3], (len(activations), n), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
init_fp8_metas_amax = jnp.zeros(
(FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
# x is the input tensor, a 2D matrix
# y, z are the 2D kernels; w, v are the biases
# out = gelu(rmsnorm(x) @ y + w) @ z + v
fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.mean(
layernorm_gelu_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm"))
def ln_gelu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
fp8_maxs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray) -> jnp.ndarray:
x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
ln_out = y * ln_scale
ln_out = jnp.asarray(ln_out, jnp.bfloat16)
fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
amax[:FP8Helper.NUM_META_PER_GEMM],
scale[:FP8Helper.NUM_META_PER_GEMM],
scale_inv[:FP8Helper.NUM_META_PER_GEMM])
linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)
x = jax.nn.gelu(linear_1_out)
x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)
fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
amax[FP8Helper.NUM_META_PER_GEMM:],
scale[FP8Helper.NUM_META_PER_GEMM:],
scale_inv[FP8Helper.NUM_META_PER_GEMM:])
output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))
bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
output += jnp.reshape(bias_2, bias_2_shape)
return output
def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
return jnp.mean(
ln_gelu_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv))
value_n_grad_primitive_func = jit(
value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
ref_fp8_max = init_fp8_max
ref_fp8_metas_amax = init_fp8_metas_amax
ref_fp8_metas_scale = init_fp8_metas_scale
ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv
pri_fp8_max = init_fp8_max
pri_fp8_metas_amax = init_fp8_metas_amax
pri_fp8_metas_scale = init_fp8_metas_scale
pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv
for _ in range(3):
ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
a, s, k1, k2, b1, b2, ref_fp8_max, ref_fp8_metas_amax,
ref_fp8_metas_scale, ref_fp8_metas_scale_inv)
for _ in range(3):
primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
primitive_k2_grad, primitive_b1_grad, primitive_b2_grad, pri_fp8_max,
pri_fp8_metas_amax, pri_fp8_metas_scale,
pri_fp8_metas_scale_inv) = value_n_grad_primitive_func(
a, s, k1, k2, b1, b2, pri_fp8_max, pri_fp8_metas_amax,
pri_fp8_metas_scale, pri_fp8_metas_scale_inv)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
jnp.asarray(ref_a_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
jnp.asarray(ref_k1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32),
dtype=jnp.bfloat16)
assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
jnp.asarray(ref_b2_grad, np.float32),
dtype=jnp.bfloat16)
@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
......@@ -294,6 +417,96 @@ def random_inputs_fixture(shape):
return out
class TestGeLu:
def ref_func(self, inputs):
func = jit(value_and_grad(lambda x: jnp.mean(jax.nn.gelu(x))))
return func(inputs)
def prim_func(self, inputs):
@jax.custom_vjp
def primitive(x):
out, _ = primitive_fwd(x)
return out
def primitive_fwd(x):
out = gelu(x)
ctx = x
return out, ctx
def primitive_bwd(ctx, g):
x = ctx
out = dgelu(g, x)
return (out,)
primitive.defvjp(primitive_fwd, primitive_bwd)
func = value_and_grad(lambda x: jnp.mean(primitive(x)))
return func(inputs)
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
def test_gelu(self, random_inputs):
x = random_inputs
prim_out, prim_grad = self.prim_func(x)
ref_out, ref_grad = self.ref_func(x)
assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
class TestGeLuFP8(TestGeLu):
def prim_func(self, inputs):
amax = self.amax
scale = self.scale
scale_inv = self.scale_inv
no_use = jnp.zeros(1, jnp.float32)
@jax.custom_vjp
def primitive(x, y, z, w):
out, _ = primitive_fwd(x, y, z, w)
return out
def primitive_fwd(x, y, z, w):
out, _ = gelu_fp8(x, amax, scale, scale_inv, jnp.float8_e4m3fn)
out = dequantize(out, x.dtype, scale_inv)
ctx = x
return out, ctx
def primitive_bwd(ctx, g):
x = ctx
dgelu, dgelu_trans, dbias, amax_out = dgelu_dbias_cast_transpose(
g, x, amax, scale, scale_inv, jnp.float8_e5m2, -1)
dgelu = dequantize(dgelu, x.dtype, scale_inv)
dgelu_trans = dequantize(dgelu_trans, x.dtype, scale_inv)
return dgelu, dgelu_trans, dbias, amax_out
primitive.defvjp(primitive_fwd, primitive_bwd)
func = value_and_grad(lambda x, y, z, w: jnp.mean(primitive(x, y, z, w)), (0, 1, 2, 3))
return func(inputs, no_use, no_use, no_use)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
def test_gelu(self, random_inputs):
self.amax = jnp.zeros(1, jnp.float32)
self.scale = jnp.ones(1, jnp.float32)
self.scale_inv = jnp.ones(1, jnp.float32)
x = random_inputs
prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
ref_out, ref_grad = self.ref_func(x)
assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
assert_allclose(dbias, jnp.sum(ref_grad, axis=tuple(range(x.ndim - 1))))
assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(prim_grad_trans,
jnp.transpose(ref_grad, (2, 0, 1)),
dtype=FP8Helper.BWD_DTYPE)
class TestGatedGeLu:
def ref_func(self, inputs):
......
......@@ -88,6 +88,7 @@ _KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = 'num_attention_heads'
_KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups'
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
......@@ -137,7 +138,25 @@ ATTRS = [{
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4
_KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('gelu',)),
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_ENABLE_ROPE: True
}, {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_ENABLE_ROPE: True
}]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
......
......@@ -818,12 +818,14 @@ class TransformerLayerAttr:
LYR_TYPE = 'layer_type'
ZERO_CEN = 'zero_centered_gamma'
TRANSPOSE_BS = 'transpose_batch_sequence'
ENABLE_ROPE = 'enable_rotary_pos_emb'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -831,6 +833,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -838,6 +841,7 @@ class TransformerLayerAttr:
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -845,6 +849,7 @@ class TransformerLayerAttr:
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -852,6 +857,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -859,6 +865,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -866,6 +873,7 @@ class TransformerLayerAttr:
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -873,6 +881,7 @@ class TransformerLayerAttr:
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -880,6 +889,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -887,6 +897,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -894,6 +905,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -901,6 +913,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -908,6 +921,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -915,6 +929,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -922,6 +937,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -929,6 +945,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -936,6 +953,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -943,6 +961,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -950,6 +969,7 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -957,6 +977,23 @@ class TransformerLayerAttr:
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: True,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: True,
TRANSPOSE_BS: False
}]
......@@ -984,6 +1021,7 @@ class TestTransformer(TestLayer):
use_bias = attrs[TransformerLayerAttr.USE_BIAS]
bias_init = WeightInit.Constant(0.0)
layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
enable_relative_embedding = True
relative_embedding = pax_fiddle.Config(RelativePositionBiases,
num_attention_heads=num_attention_heads)
......@@ -1019,6 +1057,7 @@ class TestTransformer(TestLayer):
bias_init=bias_init,
layer_type=layer_type,
enable_relative_embedding=enable_relative_embedding,
enable_rotary_pos_emb=enable_rotary_pos_emb,
relative_embedding=relative_embedding,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence)
......@@ -1040,6 +1079,7 @@ class TestTransformer(TestLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init(
"bias", bias_init),
layer_type=layer_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
enable_relative_embedding=enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
drop_path=drop_path,
......
......@@ -340,6 +340,29 @@ class MlpBlock(nn.Module):
return output
def apply_rotary_pos_emb(
inputs: jnp.ndarray,
position: jnp.ndarray,
min_timescale: int = 1,
max_timescale: int = 10000,
):
embedding_dim = inputs.shape[-1]
half_embedding_dim = embedding_dim // 2
fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
timescale = min_timescale * (max_timescale / min_timescale)**fraction
timescale = jnp.expand_dims(timescale, axis=tuple(range(inputs.ndim - 1)))
position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
sinusoid_inp = position / timescale
sin = jnp.sin(sinusoid_inp)
cos = jnp.cos(sinusoid_inp)
first_half, second_half = jnp.split(inputs, 2, axis=-1)
first_part = first_half * cos - second_half * sin
second_part = second_half * cos + first_half * sin
first_part = first_part.astype(inputs.dtype)
second_part = second_part.astype(inputs.dtype)
return jnp.concatenate([first_part, second_part], axis=-1)
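A small usage sketch of apply_rotary_pos_emb as defined above. The shapes follow the MultiHeadAttention call site below; the batch-major layout and concrete sizes are assumptions for illustration.

import jax.numpy as jnp

batch, seqlen, hidden = 2, 8, 16
x = jnp.ones((batch, seqlen, hidden), jnp.bfloat16)
# One position id per token, broadcast against the per-channel timescales inside.
position = jnp.expand_dims(jnp.arange(seqlen), axis=0)    # shape (1, seqlen)
y = apply_rotary_pos_emb(x, position)                     # same shape as x
assert y.shape == x.shape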
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
......@@ -368,6 +391,7 @@ class MultiHeadAttention(nn.Module):
float32_logits: bool = False # computes logits in float32 for stability.
scale_attn_logits: bool = False
scaled_query_init: bool = True
enable_rotary_pos_emb: bool = False
fuse_qkv: bool = True
def __post_init__(self):
......@@ -482,6 +506,15 @@ class MultiHeadAttention(nn.Module):
key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
if self.enable_rotary_pos_emb:
batch_dim = 1 if self.transpose_batch_sequence else 0
seq_dim = 1 - batch_dim
position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
query = apply_rotary_pos_emb(query, position)
key = apply_rotary_pos_emb(key, position)
query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
......@@ -802,6 +835,7 @@ class EncoderLayer(nn.Module):
zero_centered_gamma: bool = False
output_layernorm: bool = False
drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False
......@@ -854,6 +888,7 @@ class EncoderLayer(nn.Module):
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
fuse_qkv=self.fuse_qkv_params,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
name='attention')(x,
x,
encoder_mask,
......@@ -922,6 +957,7 @@ class DecoderLayer(nn.Module):
layernorm_type: str = 'layernorm'
zero_centered_gamma: bool = False
drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False
......@@ -981,6 +1017,7 @@ class DecoderLayer(nn.Module):
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
fuse_qkv=self.fuse_qkv_params,
name='self_attention')(x,
x,
......@@ -1014,6 +1051,7 @@ class DecoderLayer(nn.Module):
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
fuse_qkv=self.fuse_qkv_params,
name='encoder_decoder_attention')(y,
encoded,
......
......@@ -1349,12 +1349,6 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
Tensor *dbias,
Tensor *workspace,
cudaStream_t stream) {
CheckInputTensor(input, "cast_transpose_dbias_dgelu_input");
CheckInputTensor(gelu_input, "gelu_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2,
......@@ -1396,6 +1390,12 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
return;
}
CheckInputTensor(input, "cast_transpose_dbias_dgelu_input");
CheckInputTensor(gelu_input, "gelu_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
......
......@@ -313,13 +313,14 @@ class LayerNormFwdPrimitive(BasePrimitive):
assert x_aval.size % hidden_size == 0
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch size
x_aval.size // hidden_size, # batch size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16)
True, kwargs['zero_centered_gamma'], kwargs['epsilon']
)
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16)
True,
kwargs['zero_centered_gamma'],
kwargs['epsilon'])
wkspace_aval = out_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = out_aval.update(shape=barrier_info[0],
......@@ -384,14 +385,14 @@ class LayerNormFwdPrimitive(BasePrimitive):
hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
zero_centered_gamma,
epsilon,
sm_margin,
......@@ -523,7 +524,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
dgamma_part_aval = dgamma_aval.update(shape=dgamma_part_info[0],
dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]))
dbeta_part_aval = dbeta_aval.update(shape=dbeta_part_info[0],
dtype=te_dtype_to_jax_dtype(dbeta_part_info[1]))
dtype=te_dtype_to_jax_dtype(dbeta_part_info[1]))
return dx_aval, dgamma_aval, dbeta_aval, wkspace_aval, barrier_aval, \
dgamma_part_aval, dbeta_part_aval
......@@ -559,7 +560,6 @@ class LayerNormBwdPrimitive(BasePrimitive):
hidden_size = reduce(operator.mul, g_shape)
batch_size = reduce(operator.mul, x_shape) // hidden_size
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
......@@ -706,13 +706,14 @@ class RmsNormFwdPrimitive(BasePrimitive):
assert x_aval.size % hidden_size == 0
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch size
x_aval.size // hidden_size, # batch size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16)
False, False, kwargs['epsilon']
)
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16)
False,
False,
kwargs['epsilon'])
wkspace_aval = out_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = out_aval.update(shape=barrier_info[0],
......@@ -764,14 +765,14 @@ class RmsNormFwdPrimitive(BasePrimitive):
hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
sm_margin,
......@@ -936,13 +937,13 @@ class RmsNormBwdPrimitive(BasePrimitive):
wkspace_aval.size,
barrier_aval.size,
dgamma_part_aval.size,
0, # no dbeta_part for RMSnorm
0, # no dbeta_part for RMSnorm
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
jax_dtype_to_te_dtype(dgamma_part_aval.dtype),
TEDType.kByte, # dummy dbeta_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
sm_margin,
......@@ -1906,10 +1907,8 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
# prepare for the active fused-attn backend
batch_size = reduce(operator.mul, batch_shape)
wkspace_info = transformer_engine_jax.get_self_fused_attn_fwd_workspace_sizes(
batch_size, max_seqlen, num_heads, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training
)
batch_size, max_seqlen, num_heads, head_dim, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
wkspace_aval = qkv_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
......@@ -2271,8 +2270,8 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
attn_bias_type, attn_mask_type, dropout_probability,
num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
attn_bias_type, attn_mask_type, dropout_probability, num_heads,
num_gqa_groups, q_max_seqlen, kv_max_seqlen,
q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, kv_max_seqlen)
......@@ -2298,8 +2297,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
wkspace_info = transformer_engine_jax.get_cross_fused_attn_fwd_workspace_sizes(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training
)
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
wkspace_aval = q_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
......@@ -2336,9 +2334,8 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch_size, q_max_seqlen, kv_max_seqlen,
num_heads, num_gqa_groups, head_dim, wkspace_aval.size,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
......@@ -2532,9 +2529,8 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch_size, q_max_seqlen, kv_max_seqlen,
num_heads, num_gqa_groups, head_dim, wkspace_aval.size,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
......@@ -2666,6 +2662,222 @@ def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
is_training=is_training)
class GeluPrimitive(BasePrimitive):
"""
Gelu Forward Primitive
"""
name = "te_gelu"
multiple_results = False
inner_primitive = None
outer_primitive = None
impl_static_args = ()
@staticmethod
def abstract(x_aval):
"""
gelu abstract
"""
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_aval = core.raise_to_shaped(x_aval)
return out_aval
@staticmethod
def lowering(ctx, x):
"""
gelu lowering rules
"""
(x_aval,) = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
out_shape = ir_x_shape
out_types = [
ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
]
operands = [x]
operand_shapes = [ir_x_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
hidden_size = ir_x_shape[-1]
batch_size = reduce(operator.mul, ir_x_shape[:-1])
in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype,
in_dtype)
out = custom_caller(GeluPrimitive.name, args, opaque, False)
return [out]
@staticmethod
def impl(x):
assert GeluPrimitive.inner_primitive is not None
out = GeluPrimitive.inner_primitive.bind(x)
return out
@staticmethod
def batcher(batched_args, batch_dims):
"""
gelu batcher
"""
_check_valid_batch_dims(batch_dims)
assert GeluPrimitive.outer_primitive is not None
inputs, = batched_args
inputs_bdim, = batch_dims
out_bdims = inputs_bdim
return GeluPrimitive.outer_primitive.bind(inputs), out_bdims
@staticmethod
def infer_sharding_from_operands(mesh, arg_infos, result_infos):
"""
gelu infer_sharding_from_operands
"""
del result_infos # Unused.
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
return out_sharding
@staticmethod
def partition(mesh, arg_infos, result_infos):
"""
gelu partitioning
"""
del result_infos
x_spec = get_padded_spec(arg_infos[0])
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
impl = GeluPrimitive.impl
return mesh, impl, out_sharding, arg_shardings
register_primitive(GeluPrimitive)
def gelu(inputs: jnp.ndarray) -> jnp.ndarray:
"""
gelu wrapper
Return gelu(inputs)
Assume inputs has at least two dimensions and the memory layout is (N..., H)
"""
return GeluPrimitive.outer_primitive.bind(inputs)
class DGeluPrimitive(BasePrimitive):
"""
DGelu (GeLU Backward) Primitive
"""
name = "te_dgelu"
multiple_results = False
inner_primitive = None
outer_primitive = None
impl_static_args = ()
@staticmethod
def abstract(dz_aval, x_aval):
"""
dgelu abstract
"""
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype
assert dz_aval.shape == x_aval.shape
out_aval = core.raise_to_shaped(x_aval)
return out_aval
@staticmethod
def lowering(ctx, dz, x):
"""
dgelu lowering rules
"""
in_aval, gi_aval = ctx.avals_in
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert gi_aval.dtype == in_aval.dtype
ir_in_type = ir.RankedTensorType(dz.type)
ir_in_shape = ir_in_type.shape
gi_type = ir.RankedTensorType(x.type)
gi_shape = gi_type.shape
assert ir_in_shape == gi_shape
ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
i_hidden_size = ir_in_shape[-1]
out_dtype = ir_in_type.element_type
out_shape = gi_shape
out_types = [
ir.RankedTensorType.get(out_shape, out_dtype),
]
operands = [dz, x]
operand_shapes = [ir_in_shape, gi_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
in_dtype, in_dtype)
out = custom_caller(DGeluPrimitive.name, args, opaque, False)
return [out]
@staticmethod
def impl(dz, x):
"""
dgelu implementation
"""
assert DGeluPrimitive.inner_primitive is not None
dx = DGeluPrimitive.inner_primitive.bind(dz, x)
return dx
@staticmethod
def batcher(batched_args, batch_dims):
"""
dgelu batcher
"""
_check_valid_batch_dims(batch_dims)
assert DGeluPrimitive.outer_primitive is not None
dz, x = batched_args
_, x_bdim = batch_dims
out_bdims = x_bdim
return DGeluPrimitive.outer_primitive.bind(dz, x), out_bdims
@staticmethod
def infer_sharding_from_operands(mesh, arg_infos, result_infos):
"""
dgelu infer_sharding_from_operands
"""
del result_infos # Unused.
gelu_out_spec = get_padded_spec(arg_infos[1])
dx_sharding = NamedSharding(mesh, PartitionSpec(*gelu_out_spec))
return dx_sharding
@staticmethod
def partition(mesh, arg_infos, result_infos):
"""
dgelu partition
"""
del result_infos
dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = dx_sharding
impl = DGeluPrimitive.impl
return mesh, impl, out_shardings, arg_shardings
register_primitive(DGeluPrimitive)
def dgelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray:
"""
dgelu wrapper
Return dgelu(inputs, gelu_inputs)
"""
return DGeluPrimitive.outer_primitive.bind(inputs, gelu_inputs)
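For reference, gelu() and dgelu() pair up into a differentiable op via jax.custom_vjp, mirroring the TestGeLu pattern in the test diff above. This is a sketch, not part of this commit.

import jax

@jax.custom_vjp
def te_gelu(x):
    return gelu(x)

def te_gelu_fwd(x):
    return gelu(x), x          # keep the GeLU input as the residual

def te_gelu_bwd(x, g):
    return (dgelu(g, x),)      # dgelu(dz, x): dz scaled by dGeLU/dx

te_gelu.defvjp(te_gelu_fwd, te_gelu_bwd)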
class GatedGeluPrimitive(BasePrimitive):
"""
Gated Gelu Forward Primitive
......@@ -3190,7 +3402,7 @@ class CastFP8Primitive(BasePrimitive):
x, amax, scale, scale_inv = batched_args
x_bdim, amax_bdim, *_ = batch_dims
out_bdims = x_bdim, x_bdim, amax_bdim
out_bdims = x_bdim, amax_bdim
return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv,
out_dtype=out_dtype), out_bdims
......@@ -3386,13 +3598,14 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
assert gamma_aval.size == beta_aval.size
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # in type
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight type
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # in type
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight type
jax_dtype_to_te_dtype(out_dtype),
True, zero_centered_gamma, epsilon
)
True,
zero_centered_gamma,
epsilon)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
......@@ -3477,14 +3690,14 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
zero_centered_gamma,
epsilon,
sm_margin,
......@@ -3636,13 +3849,14 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
rsigama_dtype = jnp.float32
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch_size
x_aval.size // hidden_size, # batch_size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(out_dtype), # out te_dtype
False, False, epsilon
)
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(out_dtype), # out te_dtype
False,
False,
epsilon)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype)
......@@ -3716,14 +3930,14 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
sm_margin,
......@@ -3832,6 +4046,379 @@ def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale
epsilon=epsilon)
class GeluFp8Primitive(BasePrimitive):
"""
Gelu FP8 Primitive
"""
name = "te_gelu_fp8"
multiple_results = True
impl_static_args = (4,) #out_dtype
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
"""
te_gelu_fp8_p abstract
"""
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
# Currently the C side only supports casting to E4M3.
assert out_dtype == jnp.float8_e4m3fn
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
return out_aval, updated_amax_aval
@staticmethod
def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
"""
te_gelu_fp8_p lowering rules
"""
x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
hidden_size = ir_x_shape[-1]
batch_size = reduce(operator.mul, ir_x_shape[:-1])
out_shape = ir_x_shape
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, amax, scale, scale_inv]
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size),
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype))
out = custom_caller(GeluFp8Primitive.name,
args,
opaque,
False,
operand_output_aliases={1: 1})
return out
@staticmethod
def impl(x, amax, scale, scale_inv, out_dtype):
"""
to describe implementation
"""
assert GeluFp8Primitive.inner_primitive is not None
out, updated_amax = GeluFp8Primitive.inner_primitive.bind(x,
amax,
scale,
scale_inv,
out_dtype=out_dtype)
return out, updated_amax
@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype):
"""
to describe batch rules for vmap
"""
_check_valid_batch_dims(batch_dims)
assert GeluFp8Primitive.outer_primitive is not None
x, amax, scale, scale_inv = batched_args
x_bdim, amax_bdim, _, _ = batch_dims
out_bdims = x_bdim, amax_bdim
return GeluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv,
out_dtype=out_dtype), out_bdims
@staticmethod
def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
del out_dtype, result_infos
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
return (out_sharding, amax_sharding)
@staticmethod
def partition(out_dtype, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (out_sharding, amax_sharding)
def sharded_impl(x, amax, scale, scale_inv):
local_x, local_amax = GeluFp8Primitive.impl(x,
amax,
scale,
scale_inv,
out_dtype=out_dtype)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
return local_x, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(GeluFp8Primitive)
def gelu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
gelu fp8 wrapper
Return FP8(gelu(x))
"""
return GeluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)
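A usage sketch mirroring TestGeLuFP8 above: apply GeLU with FP8 output and round-trip through dequantize (the helper imported from transformer_engine.jax.dot in the test diff). The initial amax/scale values are the test's placeholders.

import jax.numpy as jnp

amax = jnp.zeros(1, jnp.float32)
scale = jnp.ones(1, jnp.float32)
scale_inv = jnp.ones(1, jnp.float32)
x = jnp.ones((32, 64), jnp.bfloat16)
y_fp8, updated_amax = gelu_fp8(x, amax, scale, scale_inv, jnp.float8_e4m3fn)
y = dequantize(y_fp8, x.dtype, scale_inv)   # back to the input dtype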
class DGeluDBiasCastTransposePrimitive(BasePrimitive):
"""
DGelu DBias Cast Transpose Primitive
"""
name = "te_dgelu_dbias_cast_transpose"
multiple_results = True
# out_dtype, static_axis_boundary, transpose_axis_boundary
impl_static_args = (5, 6, 7)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
static_axis_boundary, transpose_axis_boundary):
"""
te_dgelu_dbias_cast_transpose_p abstract
"""
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_hidden_size = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1]
assert ir_hidden_size == gi_hidden_size
t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary)
out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
dbias_shape = (*x_aval.shape[:static_axis_boundary + 1], gi_hidden_size)
dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
wkspace_info, = transformer_engine_jax.get_dgelu_dbias_ct_workspace_sizes(
x_aval.size // gi_hidden_size,
gi_hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
)
wkspace_aval = x_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
return out, t_out, dbias, updated_amax_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
te_dgelu_dbias_cast_transpose_p outer abstract
"""
out, t_out, dbias, updated_amax_aval, _ = \
DGeluDBiasCastTransposePrimitive.abstract(*args, **kwargs)
return out, t_out, dbias, updated_amax_aval
@staticmethod
def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
transpose_axis_boundary):
"""
te_dgelu_dbias_cast_transpose_p lowering rules
"""
dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
assert ir_dz_shape == x_shape
batch_size = reduce(operator.mul, ir_dz_shape[:-1])
ir_hidden_size = ir_dz_shape[-1]
contracted_x_shape = (batch_size, ir_hidden_size)
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary,
transpose_axis_boundary)
dbias_shape = (*x_shape[:static_axis_boundary + 1], ir_hidden_size)
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(x_shape, ir_out_dtype),
ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
]
operands = [dz, x, amax, scale, scale_inv]
operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_wk_descriptor(
contracted_x_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype))
out = custom_caller(DGeluDBiasCastTransposePrimitive.name,
args,
opaque,
False,
operand_output_aliases={2: 3})
return out
@staticmethod
def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary,
transpose_axis_boundary):
"""
to describe implementation
"""
assert DGeluDBiasCastTransposePrimitive.inner_primitive is not None
out, t_out, dbias, updated_amax, _ = DGeluDBiasCastTransposePrimitive.inner_primitive.bind(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary)
return out, t_out, dbias, updated_amax
@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
transpose_axis_boundary):
"""
to describe batch rules for vmap
"""
del static_axis_boundary
_check_valid_batch_dims(batch_dims)
assert DGeluDBiasCastTransposePrimitive.outer_primitive is not None
dz, x, amax, scale, scale_inv = batched_args
x_bdim, _, amax_bdim, _, _ = batch_dims
# Minus batch dim.
transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
transpose_axis_boundary += 1 # Plus batch dim
out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
return DGeluDBiasCastTransposePrimitive.outer_primitive.bind(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=x_bdim,
transpose_axis_boundary=transpose_axis_boundary), out_bdims
@staticmethod
def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh,
arg_infos, result_infos):
del out_dtype, result_infos
x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
return (out_sharding, transposed_out_sharding, dbias_sharding, amax_sharding)
@staticmethod
def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[1])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_sharding,
amax_sharding)
def sharded_impl(dz, x, amax, scale, scale_inv):
local_out, local_t_out, local_dbias, local_amax = DGeluDBiasCastTransposePrimitive.impl(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary)
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
return local_out, local_t_out, global_dbias, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(DGeluDBiasCastTransposePrimitive)
def dgelu_dbias_cast_transpose(
dz: jnp.ndarray,
x: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
out_dtype: TEDType,
static_axis_boundary: int,
transpose_axis_boundary: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
dgelu + dbias + cast transpose fusion wrapper
Return FP8(dgelu(dz, x)), dbias
"""
if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes
return DGeluDBiasCastTransposePrimitive.outer_primitive.bind(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary)
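And a matching sketch for the backward fusion, following TestGeLuFP8.primitive_bwd above. It assumes dz (incoming gradient), x (saved GeLU input), and the amax/scale/scale_inv metadata are already in scope, as in the test.

import jax.numpy as jnp

dx_fp8, dx_t_fp8, dbias, updated_amax = dgelu_dbias_cast_transpose(
    dz, x, amax, scale, scale_inv, jnp.float8_e5m2,
    static_axis_boundary=-1)     # -1: no leading static (batch) axes
dx = dequantize(dx_fp8, x.dtype, scale_inv)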
class GatedGeluFp8Primitive(BasePrimitive):
"""
Gated Gelu FP8 Primitive
......
......@@ -25,6 +25,10 @@ pybind11::dict Registrations() {
pybind11::dict dict;
dict["te_transpose"] = EncapsulateFunction(Transpose);
dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose);
dict["te_gelu"] = EncapsulateFunction(Gelu);
dict["te_gelu_fp8"] = EncapsulateFunction(GeluFP8);
dict["te_dgelu"] = EncapsulateFunction(DGelu);
dict["te_dgelu_dbias_cast_transpose"] = EncapsulateFunction(DGeluDBiasCastTranspose);
dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu);
dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
......@@ -55,6 +59,7 @@ pybind11::dict Registrations() {
PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("registrations", &Registrations);
m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor);
m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor);
m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
......@@ -62,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes);
m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes);
......
......@@ -61,18 +61,29 @@ pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape,
return PackOpaque(desc);
}
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
const std::vector<size_t> &wkshape, DType in_dtype,
DType out_dtype, DType wk_dtype) {
CustomCallCommonWkDescriptor desc;
desc.shape.from_vector(shape);
desc.wkshape.from_vector(wkshape);
desc.in_dtype = in_dtype;
desc.out_dtype = out_dtype;
desc.wk_dtype = wk_dtype;
return PackOpaque(desc);
}
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, size_t barrier_size,
size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
DType x_dtype, DType w_dtype,
DType wkspace_dtype, DType barrier_dtype,
DType dgamma_part_dtype, DType dbeta_part_dtype,
bool zero_centered_gamma, float eps, int sm_margin) {
return PackOpaque(CustomCallNormDescriptor{batch_size, hidden_size, wkspace_size, barrier_size,
dgamma_part_sizes, dbeta_part_sizes,
x_dtype, w_dtype, wkspace_dtype, barrier_dtype,
dgamma_part_dtype, dbeta_part_dtype,
zero_centered_gamma, eps, sm_margin});
DType x_dtype, DType w_dtype, DType wkspace_dtype,
DType barrier_dtype, DType dgamma_part_dtype,
DType dbeta_part_dtype, bool zero_centered_gamma,
float eps, int sm_margin) {
return PackOpaque(CustomCallNormDescriptor{
batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes, dbeta_part_sizes,
x_dtype, w_dtype, wkspace_dtype, barrier_dtype, dgamma_part_dtype, dbeta_part_dtype,
zero_centered_gamma, eps, sm_margin});
}
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
......@@ -83,11 +94,10 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t paddin
}
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
DType dtype, DType wkspace_dtype, bool is_training) {
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
DType wkspace_dtype, bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim, wkspace_size,
scaling_factor, dropout_probability, bias_type, mask_type, dtype, wkspace_dtype,
......@@ -149,6 +159,138 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
input_cast_trans_tensor.data(), stream);
}
void GeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
scale, scale_inverse);
nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
}
void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
GeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output);
}
void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
float *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
GeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
output);
}
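The amax/scale/scale_inv trio handled above implements delayed-scaling FP8: the kernel scales the output into FP8 range, records the absolute maximum for the next scale update, and publishes the inverse scale for dequantization. A conceptual sketch of that bookkeeping, illustrative only since the real cast happens inside nvte_gelu:

import jax.numpy as jnp

def fp8_scaling_sketch(x, scale):
    # Record max(|x|) so the framework can retune `scale` on a later step.
    amax = jnp.max(jnp.abs(x))
    # Scale into the representable range before the (simulated) FP8 cast;
    # 448 is the largest finite value of the E4M3 format.
    x_q = jnp.clip(x * scale, -448.0, 448.0)
    return x_q, amax, 1.0 / scale  # quantized value, amax, scale_inv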
void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *gelu_input = buffers[1];
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto gelu_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);
nvte_dgelu(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(), stream);
}
pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto gelu_input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
auto dbias_shape = std::vector<size_t>{hidden_size};
auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto gelu_input_tensor = TensorWrapper(nullptr, gelu_input_shape, in_dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);
TensorWrapper dummy_workspace;
nvte_cast_transpose_dbias_dgelu(input_tensor.data(), gelu_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), dummy_workspace.data(), nullptr);
auto work_shape = MakeShapeVector(dummy_workspace.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}
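The dummy-workspace call above is the usual size-query pattern: run the NVTE kernel once with an empty workspace tensor and a null stream, then read back the shape and dtype it requests. A minimal sketch of how the Python side pairs that query with the new workspace descriptor; the DType enum member name is an assumption:

import transformer_engine_jax as tex

batch_size, hidden_size = 512, 1024
in_dtype = out_dtype = tex.DType.kBFloat16  # enum member name is an assumption

# Query the workspace required by the fused dgelu + dbias + cast-transpose.
(wk_shape, wk_dtype), = tex.get_dgelu_dbias_ct_workspace_sizes(
    batch_size, hidden_size, in_dtype, out_dtype)

# Pack operand and workspace metadata into the opaque blob that
# DGeluDBiasCastTranspose unpacks as a CustomCallCommonWkDescriptor.
opaque = tex.pack_common_wk_descriptor(
    (batch_size, hidden_size), wk_shape, in_dtype, out_dtype, wk_dtype)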
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *gelu_input = buffers[1];
float *amax = reinterpret_cast<float *>(buffers[2]);
float *scale = reinterpret_cast<float *>(buffers[3]);
float *scale_inv = reinterpret_cast<float *>(buffers[4]);
auto *output = buffers[5];
auto *output_trans = buffers[6];
auto *dbias = buffers[7];
float *amax_out = reinterpret_cast<float *>(buffers[8]);
void *workspace_ptr = buffers[9];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto gelu_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto dbias_shape = std::vector<size_t>{n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor =
TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
nvte_cast_transpose_dbias_dgelu(input_tensor.data(), gelu_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
}
void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
auto input_shape = std::vector<size_t>{m, n * 2};
......@@ -251,10 +393,10 @@ void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *op
output_trans_tensor.data(), stream);
}
pybind11::tuple GetLayerNormForwardWorkspaceSizes(
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma, float eps
) {
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma,
float eps) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
......@@ -289,13 +431,12 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(
std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()));
}
void LayerNormForwardImpl(size_t batch_size, size_t hidden_size,
size_t workspace_size, size_t barrier_size,
bool zero_centered_gamma, float eps, void *input, DType in_dtype,
void *weight, DType w_dtype, void *bias, void *output, DType out_dtype,
void *workspace, DType work_dtype, void *barrier, DType barrier_dtype,
void *mu, void *rsigma, float *amax, float *scale, float *scale_inv,
cudaStream_t stream) {
void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspace_size,
size_t barrier_size, bool zero_centered_gamma, float eps, void *input,
DType in_dtype, void *weight, DType w_dtype, void *bias, void *output,
DType out_dtype, void *workspace, DType work_dtype, void *barrier,
DType barrier_dtype, void *mu, void *rsigma, float *amax, float *scale,
float *scale_inv, cudaStream_t stream) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
......@@ -333,10 +474,10 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size,
}
}
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm,
bool zero_centered_gamma, float eps
) {
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype,
bool is_layer_norm, bool zero_centered_gamma,
float eps) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
......@@ -373,8 +514,8 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dummy_dgamma_part_tensor.data(), nullptr,
num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
dummy_dgamma_part_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
dummy_barrier_tensor.data());
dbeta_part_shape = std::vector<size_t>{0, 0};
}
......@@ -388,15 +529,13 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(
std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype()));
}
void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size,
size_t wkspace_size, size_t barrier_size,
size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
bool zero_centered_gamma, float eps,
void *input, DType in_dtype, void *weight, DType w_dtype, void *ograd,
void *workspace, DType wkspace_dtype, void *barrier, DType barrier_dtype,
void *mu, void *rsigma, void *xgrad, void *wgrad, void *dbeta,
void *dgamma_part, DType dgamma_dtype,
void* dbeta_part, DType dbeta_dtype,
void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size,
size_t barrier_size, size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
bool zero_centered_gamma, float eps, void *input, DType in_dtype,
void *weight, DType w_dtype, void *ograd, void *workspace,
DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu,
void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *dgamma_part,
DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype,
cudaStream_t stream) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
......@@ -479,10 +618,10 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias,
output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
}
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -513,10 +652,10 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias,
output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
}
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -551,11 +690,11 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto *dgamma_part = buffers[10];
auto *dbeta_part = buffers[11];
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
dgamma_part_sizes, dbeta_part_sizes, zero_centered_gamma, eps,
input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype,
barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta,
dgamma_part, dgamma_part_dtype, dbeta_part, dbeta_part_dtype, stream);
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, stream);
}
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -588,10 +727,10 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias,
output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
}
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -622,10 +761,10 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
auto sm_margin = desc.sm_margin;
auto out_dtype = in_dtype;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias,
output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
}
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -659,11 +798,11 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size,
dgamma_part_sizes, dbeta_part_sizes, zero_centered_gamma, eps,
input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype,
barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta,
dgamma_part, dgamma_part_dtype, dbeta_part, dbeta_part_dtype, stream);
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, stream);
}
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -741,8 +880,8 @@ void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto io_shape = std::vector<size_t>{desc.batch_size, desc.head_dim,
desc.q_seqlen, desc.k_seqlen};
auto io_shape =
std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto mask_shape = std::vector<size_t>{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
......@@ -818,11 +957,11 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
- common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812
- common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359
*/
void PrepareFusedAttnForwardAuxTensors(
NVTETensorPack *tensor_pack, const CustomCallFusedAttnDescriptor *desc,
NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
void *softmax_buf, void *rng_state_buf = nullptr, void *bias_buf = nullptr
) {
void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
const CustomCallFusedAttnDescriptor *desc,
NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
void *softmax_buf, void *rng_state_buf = nullptr,
void *bias_buf = nullptr) {
auto batch_size = desc->batch_size;
auto num_heads = desc->num_heads;
auto q_max_seqlen = desc->q_max_seqlen;
......@@ -833,8 +972,8 @@ void PrepareFusedAttnForwardAuxTensors(
tensor_pack->size = 1;
Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
softmax_aux->data.dptr = softmax_buf;
softmax_aux->data.shape = std::vector<size_t>{
batch_size, num_heads, q_max_seqlen, kv_max_seqlen};
softmax_aux->data.shape =
std::vector<size_t>{batch_size, num_heads, q_max_seqlen, kv_max_seqlen};
softmax_aux->data.dtype = desc->dtype;
// arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
......@@ -867,10 +1006,10 @@ void PrepareFusedAttnForwardAuxTensors(
TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()?
*/
void PrepareFusedAttnBackwardAuxTensors(
NVTETensorPack* tensor_pack, const CustomCallFusedAttnDescriptor *desc,
NVTE_Fused_Attn_Backend backend, void* softmax_buf, void* rng_state_buf, void* bias_buf
) {
void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack,
const CustomCallFusedAttnDescriptor *desc,
NVTE_Fused_Attn_Backend backend, void *softmax_buf,
void *rng_state_buf, void *bias_buf) {
// Backward calls put everything into the tensor pack for every backend
// so we set dummy bias_type and backend choices here to follow the correct code path
auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
......@@ -880,17 +1019,16 @@ void PrepareFusedAttnBackwardAuxTensors(
// correct softmax shape for max512 sequence length kernel
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
Tensor* softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
softmax_aux->data.shape.at(3) = desc->kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks}
softmax_aux->data.dtype = desc->dtype;
}
}
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
) {
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
......@@ -898,17 +1036,16 @@ pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto cu_seqlens_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto o_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim}, dtype);
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, num_heads, num_heads,
max_seqlen, max_seqlen, head_dim);
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_heads, max_seqlen, max_seqlen, head_dim);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
......@@ -916,9 +1053,9 @@ pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
rng_state_tensor.data(), max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
......@@ -957,8 +1094,8 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
// input tensors
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
auto cu_seqlens_tensor = TensorWrapper(
cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
......@@ -969,9 +1106,8 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, num_heads, num_heads,
max_seqlen, max_seqlen, head_dim);
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_heads, max_seqlen, max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, max_seqlen, max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later)
......@@ -995,10 +1131,9 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
}
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
) {
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
......@@ -1014,8 +1149,8 @@ pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto cu_seqlens_tensor = TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1},
DType::kInt32);
auto cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
......@@ -1084,11 +1219,10 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, num_heads, num_heads,
max_seqlen, max_seqlen, head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend,
softmax_aux, rng_state, bias);
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_heads, max_seqlen, max_seqlen, head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
// cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
......@@ -1107,11 +1241,9 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
}
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
) {
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
......@@ -1128,10 +1260,10 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto q_cu_seqlens_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
......@@ -1139,12 +1271,12 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
......@@ -1203,9 +1335,9 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, num_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, head_dim);
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later)
......@@ -1215,25 +1347,23 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
softmax_aux);
// cuDNN workspace
auto workspace_tensor = TensorWrapper(
workspace, std::vector<size_t>{descriptor.wkspace_size}, descriptor.wkspace_dtype);
auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
descriptor.wkspace_dtype);
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.is_training, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
) {
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
......@@ -1252,10 +1382,10 @@ pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto q_cu_seqlens_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
......@@ -1267,8 +1397,8 @@ pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
s_tensor.data(), // not used for FP16/BF16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
......@@ -1325,21 +1455,21 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto q_cu_seqlens_tensor = TensorWrapper(
q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = TensorWrapper(
kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// auxiliary tensors (propagated from the forward pass)
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, num_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend,
softmax_aux, rng_state, bias);
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
// cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
......@@ -1352,8 +1482,8 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
s_tensor.data(), // not used for FP16/BF16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_input_tensors);
}
......
......@@ -52,6 +52,18 @@ struct CustomCallCommonDescriptor {
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype);
struct CustomCallCommonWkDescriptor {
Shape shape;
Shape wkshape;
DType in_dtype;
DType out_dtype;
DType wk_dtype;
};
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
const std::vector<size_t> &wkshape, DType in_dtype,
DType out_dtype, DType wk_dtype);
struct CustomCallNormDescriptor {
size_t batch_size;
size_t hidden_size;
......@@ -73,10 +85,10 @@ struct CustomCallNormDescriptor {
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, size_t barrier_size,
size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
DType x_dtype, DType w_dtype,
DType wkspace_dtype, DType barrier_dtype,
DType dgamma_part_dtype, DType dbeta_part_dtype,
bool zero_centered_gamma, float eps, int sm_margin);
DType x_dtype, DType w_dtype, DType wkspace_dtype,
DType barrier_dtype, DType dgamma_part_dtype,
DType dbeta_part_dtype, bool zero_centered_gamma,
float eps, int sm_margin);
struct SoftmaxDescriptor {
size_t batch_size;
......@@ -110,11 +122,10 @@ struct CustomCallFusedAttnDescriptor {
};
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
DType dtype, DType wkspace_dtype, bool is_training);
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
DType wkspace_dtype, bool is_training);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......@@ -127,6 +138,18 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......@@ -136,20 +159,20 @@ void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetLayerNormForwardWorkspaceSizes(
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma, float eps
);
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma,
float eps);
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm,
bool zero_centered_gamma, float eps
);
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype,
bool is_layer_norm, bool zero_centered_gamma,
float eps);
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......@@ -182,39 +205,33 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
std::size_t opaque_len);
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
bool is_training);
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
bool is_training);
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_heads, size_t num_gqa_groups, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
......
......@@ -23,8 +23,10 @@ from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..mlp import layernorm_geglu_fp8_mlp, geglu
from ..mlp import layernorm_gelu_fp8_mlp, gelu
from ..softmax import is_softmax_kernel_available
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
PRNGKey = Any
Shape = Tuple[int, ...]
......@@ -502,6 +504,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
If set to False, return None as the second tensor in outputs.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of the sharding constraint applied to the input of layernorm,
e.g. (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no sharding
constraint is inserted.
dot_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of the sharding constraint applied to the input of the dot,
e.g. (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no sharding
constraint is inserted.
Optimization parameters
-----------------------
......@@ -534,6 +544,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None
sharding_type = None
......@@ -571,6 +583,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
) and not self.return_layernorm_output and self.enable_layernorm
if self.enable_layernorm:
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
assert self.axis == -1    # Only support axis == -1 at this moment
features = inputs.shape[-1]
......@@ -626,8 +640,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
fp8_meta_package,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon)
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_input_axes)
else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
z = type_safe_dot_general(y,
kernel,
fp8_meta_pkg=fp8_meta_package,
......@@ -730,6 +747,18 @@ class LayerNormMLP(TransformerEngineBase):
Dimensions that will share the same dropout mask for hidden
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of the sharding constraint applied to the input of layernorm,
e.g. (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no sharding
constraint is inserted.
dot_1_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of the sharding constraint applied to the input of the 1st
dot, e.g. (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no
sharding constraint is inserted.
dot_2_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of the sharding constraint applied to the input of the 2nd
dot, e.g. (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no
sharding constraint is inserted.
Optimization parameters
-----------------------
......@@ -765,6 +794,9 @@ class LayerNormMLP(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_1_input_axes: Tuple[str, ...] = None
dot_2_input_axes: Tuple[str, ...] = None
major_sharding_type = None
def __post_init__(self):
......@@ -812,13 +844,28 @@ class LayerNormMLP(TransformerEngineBase):
normalize_acts.append(act.lower())
return tuple(normalize_acts) in geglu_act_pool
use_fused_ln_mlp = fuse_layernorm \
def is_gelu(acts):
gelu_act_pool = [('gelu',)]
normalize_acts = []
for act in acts:
if not isinstance(act, str):
return False
normalize_acts.append(act.lower())
return tuple(normalize_acts) in gelu_act_pool
use_fused_ln_geglu_mlp = fuse_layernorm \
and (not self.use_bias) and is_geglu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3)
use_fused_ln_gelu_mlp = fuse_layernorm \
and self.use_bias and is_gelu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3)
# LayerNorm
if self.enable_layernorm:
assert self.axis == -1 # Only support axis == -1 at this moment
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
features = inputs.shape[-1]
......@@ -883,7 +930,10 @@ class LayerNormMLP(TransformerEngineBase):
kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
contract_ind = tuple(range(0, len(axis)))
if use_fused_ln_mlp:
ffn1_ckpt_name = 'ffn1'
ffn2_ckpt_name = 'ffn2'
if use_fused_ln_geglu_mlp:
assert self.axis == -1    # Only support axis == -1 at this moment
out = layernorm_geglu_fp8_mlp(y,
......@@ -892,8 +942,41 @@ class LayerNormMLP(TransformerEngineBase):
fp8_meta_package,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon)
else: # not use_fused_ln_mlp
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name)
elif use_fused_ln_gelu_mlp:
assert self.axis == -1    # Only support axis == -1 at this moment
bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init,
intermediate_dim,
jnp.float32,
axes=self.bias_axes_1)
bias_1 = bias_1.astype(self.dtype)
bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init, (hidden_size,),
jnp.float32,
axes=self.bias_axes_2)
bias_2 = bias_2.astype(self.dtype)
out = layernorm_gelu_fp8_mlp(y,
scale,
ln_bias, [kernel_1, kernel_2], [bias_1, bias_2],
fp8_meta_package,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name)
else:    # neither use_fused_ln_geglu_mlp nor use_fused_ln_gelu_mlp
# DenseGeneral 1
gemm1_fp8_meta_package = None if fp8_meta_package is None \
......@@ -906,8 +989,11 @@ class LayerNormMLP(TransformerEngineBase):
gemm1_fp8_meta_package,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon)
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_1_input_axes)
else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
x = type_safe_dot_general(y,
kernel_1,
fp8_meta_pkg=gemm1_fp8_meta_package,
......@@ -924,11 +1010,14 @@ class LayerNormMLP(TransformerEngineBase):
bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
x += jnp.reshape(bias, bias_shape)
x = checkpoint_name(x, 'ffn1')
x = checkpoint_name(x, ffn1_ckpt_name)
activations = []
if is_geglu(self.activations):
z = geglu(x)
elif is_gelu(self.activations):
z = gelu(x)
z = jnp.reshape(z, (*z.shape[:-2], -1))
else:
x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(self.activations):
......@@ -942,6 +1031,8 @@ class LayerNormMLP(TransformerEngineBase):
rng_collection=self.intermediate_dropout_rng_name)(
z, deterministic=deterministic)
z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
# DenseGeneral 2
gemm2_fp8_meta_package = None if fp8_meta_package is None \
else fp8_meta_package.get_package_by_gemm_idx(1)
......@@ -960,6 +1051,6 @@ class LayerNormMLP(TransformerEngineBase):
bias = bias.astype(self.dtype)
out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))
out = checkpoint_name(out, 'ffn2')
out = checkpoint_name(out, ffn2_ckpt_name)
return out, ln_output    # Output, layer_norm_output
......@@ -27,8 +27,12 @@ from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from ..fused_attn import is_fused_attn_kernel_available
from ..fused_attn import self_fused_attn, cross_fused_attn
from ..softmax import SoftmaxType
from ..sharding import global_mesh_resource, num_of_devices
from ..sharding import with_sharding_constraint
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
PRNGKey = Any
Shape = Tuple[int, ...]
......@@ -39,17 +43,6 @@ PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Preci
Initializer = Callable[[PRNGKey, Shape, DType], Array]
LogicalRules = Sequence[Tuple[str, Union[str, None]]]
BATCH_AXES = 'nvte_batch'
SEQLEN_AXES = 'nvte_seqlen'
HEAD_AXES = 'nvte_head'
HIDDEN_AXES = 'nvte_hidden'
HIDDEN_TP_AXES = 'nvte_hidden_tp'
JOINED_AXES = 'nvte_joined'
W_NO_SHARD_AXES = 'nvte_w_no_shard'
W_FSDP_AXES = 'nvte_w_fsdp'
W_TP_AXES = 'nvte_w_tp'
W_JOINED_AXES = 'nvte_w_joined'
def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
# Generate broadcast dims for drop_path.
......@@ -101,36 +94,8 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
else:
rules_map[key] = [val]
gsr = global_mesh_resource()
batch_dim_rule = []
if gsr.dp_resource is not None:
batch_dim_rule.append(gsr.dp_resource)
if gsr.fsdp_resource is not None and gsr.dp_resource != gsr.fsdp_resource:
batch_dim_rule.append(gsr.fsdp_resource)
if len(batch_dim_rule) <= 0:
batch_dim_rule = None
elif len(batch_dim_rule) == 1:
batch_dim_rule = batch_dim_rule[0]
else:
batch_dim_rule = tuple(batch_dim_rule)
te_logical_axis_rules = (
(BATCH_AXES, batch_dim_rule),
(SEQLEN_AXES, None),
(HEAD_AXES, gsr.tp_resource),
(HIDDEN_AXES, None),
(HIDDEN_TP_AXES, gsr.tp_resource),
(JOINED_AXES, None),
(W_NO_SHARD_AXES, None),
(W_FSDP_AXES, gsr.fsdp_resource),
(W_TP_AXES, gsr.tp_resource),
(W_JOINED_AXES, None),
)
extended_rules = [*rules]
for item in te_logical_axis_rules:
for item in get_sharding_map_logic_axis_to_mesh_axis().items():
key = item[0]
val = item[1]
if key in rules_map:
......@@ -143,18 +108,6 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
return tuple(extended_rules)
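A minimal usage sketch for the merge above; the import path is an assumption based on the module this hunk edits, and user-supplied rules take precedence over the TE defaults:

# Hypothetical import path for the function defined above.
from transformer_engine.jax.flax.transformer import extend_logical_axis_rules

# Merge a user rule with the TE logical-axis defaults (BATCH_AXES, ...),
# each mapped to the mesh axes recorded in the global mesh resource.
rules = extend_logical_axis_rules((('my_embed', 'model'),))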
def _with_sharding_constraint(x: Array, logical_axis_names: Shape):
assert len(x.shape) == len(logical_axis_names)
rules = extend_logical_axis_rules(tuple())
rules_dict = {}
for key, value in rules:
rules_dict[key] = value
mesh_axis_names = [rules_dict[name] for name in logical_axis_names]
pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
return with_sharding_constraint(x, pspec)
def _merge_mask(func, *masks: Optional[Array]):
masks = [m for m in masks if m is not None]
if not masks:
......@@ -175,7 +128,10 @@ def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
def combine_biases(*masks: Optional[Array]):
"""Combine attention biases."""
func = lambda a, b: a + b
def func(a, b):
return a + b
return _merge_mask(func, *masks)
......@@ -234,8 +190,8 @@ def core_attention(query: Array,
attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
attn_weights = _with_sharding_constraint(attn_weights,
(BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))
attn_weights = with_sharding_constraint_by_logical_axes(
attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))
# When a bias is present, the computation is performed as Softmax(attn_weights * scale + bias).
# In this case, the scale cannot be fused into the Softmax module.
......@@ -270,6 +226,39 @@ def core_attention(query: Array,
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
def rotary_pos_emb(x: Array, windows: Tuple[int, int], transpose_batch_sequence: bool):
"""
Rotary Positional Embedding
x should have shape
[Batch, Seqlen, ..., Hidden] if transpose_batch_sequence is False, or
[Seqlen, Batch, ..., Hidden] if transpose_batch_sequence is True.
"""
embed_dim = x.shape[-1]
half_embed_dim = embed_dim // 2
min_window = windows[0]
max_window = windows[1]
fraction = 2 * jnp.arange(0, half_embed_dim) / embed_dim
time_scales = min_window * (max_window / min_window)**fraction
time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))
batch_dim = 1 if transpose_batch_sequence else 0
seq_dim = 1 - batch_dim
positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))
sinusoidal_positions = positions / time_scales
sin = jnp.sin(sinusoidal_positions)
cos = jnp.cos(sinusoidal_positions)
x1, x2 = jnp.split(x, 2, axis=-1)
part_1 = (x1 * cos - x2 * sin).astype(x.dtype)
part_2 = (x2 * cos + x1 * sin).astype(x.dtype)
return jnp.concatenate([part_1, part_2], axis=-1)
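A quick sanity check for the helper above: RoPE is shape-preserving, and at position zero the rotation angle is zero, so the first token passes through unchanged.

import jax.numpy as jnp

# [Seqlen, Batch, Heads, Hidden] layout, i.e. transpose_batch_sequence=True.
x = jnp.ones((128, 8, 16, 64), dtype=jnp.float32)
y = rotary_pos_emb(x, windows=(1, 10000), transpose_batch_sequence=True)
assert y.shape == x.shape        # RoPE never changes the shape
assert jnp.allclose(y[0], x[0])  # position 0: sin = 0, cos = 1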
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
......@@ -331,6 +320,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type: {'causal', 'padding'}, default = 'causal'
Type of attention mask passed into softmax operation.
Introduced in v0.10.0.
enable_rotary_pos_emb: bool, default = False
Whether to apply rotary position embedding to the projected query and key.
rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
Indicate the min and max time-scales of the rotary position embedding;
only used when :attr:`enable_rotary_pos_emb=True`.
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism for operations other than dot.
Optimization parameters
-----------------------
......@@ -368,9 +364,12 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
attn_mask_type: str = 'causal'
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
dtype: DType = jnp.float32
fuse_qkv: bool = True
transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
float32_logits: bool = False # computes logits in float32 for stability.
......@@ -501,6 +500,22 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
f"Fused attention is not enabled. Because " \
f"{reason}fall back to unfused attention.")
def generate_batch_seqlen_logical_axes(is_sharded_seq):
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
axes = [None, None]
axes[batch_dim] = BATCH_AXES
axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
return tuple(axes)
inputs_logical_axes_maybe_sp = (*generate_batch_seqlen_logical_axes(
self.enable_sequence_parallel), HIDDEN_AXES)
inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)
inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)
residual = inputs_q
if self.fuse_qkv:
if is_qkvpack:
......@@ -520,6 +535,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(W_JOINED_AXES, W_TP_AXES),
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
name='qkv',
dtype=self.dtype)(inputs_q)
qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj')
......@@ -544,6 +561,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_TP_AXES,),
dtype=self.dtype,
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
name='query')(inputs_q)
if is_self_attn:
......@@ -591,6 +610,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_TP_AXES,),
dtype=self.dtype,
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
name='query')(inputs_q)
if is_self_attn:
......@@ -604,6 +625,23 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
assert ln_out is not None
residual = ln_out
if self.enable_rotary_pos_emb:
if self.fuse_qkv and use_fused_attn:
if is_qkvpack:
query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
else:
key, value = jnp.split(kv_proj, [1], axis=-2)
query = rotary_pos_emb(query, self.rotary_pos_emb_windows,
self.transpose_batch_sequence)
key = rotary_pos_emb(key, self.rotary_pos_emb_windows, self.transpose_batch_sequence)
if use_fused_attn:
if is_qkvpack:
qkv_proj = jnp.concatenate([query, key, value], axis=-2)
else:
kv_proj = jnp.concatenate([key, value], axis=-2)
if not use_fused_attn:
query = checkpoint_name(query, 'query_proj')
key = checkpoint_name(key, 'key_proj')
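The (min, max) window pair passed here sets the geometric range of rotary time-scales. A minimal sketch of the usual formulation, assuming rotary_pos_emb follows the standard geometric spacing (illustrative only, not the exact TE kernel):

import jax.numpy as jnp

def rope_timescales(head_dim: int, min_ts: float = 1.0, max_ts: float = 10000.0):
    # One time-scale per rotated channel pair; position p is encoded by the
    # angles p / timescale, applied as 2D rotations over (even, odd) channels.
    fraction = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim
    return min_ts * (max_ts / min_ts) ** fraction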
......@@ -615,9 +653,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
(SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \
if self.transpose_batch_sequence \
else (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
query = _with_sharding_constraint(query, qkv_sharding_constraint)
key = _with_sharding_constraint(key, qkv_sharding_constraint)
value = _with_sharding_constraint(value, qkv_sharding_constraint)
query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
if decode:
is_initialized = self.has_variable('cache', 'cached_key')
......@@ -679,7 +717,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES)
qkv_proj = _with_sharding_constraint(qkv_proj, qkv_sharding_constraint)
qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj,
qkv_sharding_constraint)
x = self_fused_attn(qkv_proj,
bias,
mask,
......@@ -696,8 +736,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES)
query = _with_sharding_constraint(query, q_sharding_constraint)
kv_proj = _with_sharding_constraint(kv_proj, kv_sharding_constraint)
query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
x = cross_fused_attn(query,
kv_proj,
......@@ -748,7 +788,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
(SEQLEN_AXES, BATCH_AXES, HIDDEN_TP_AXES) \
if self.transpose_batch_sequence \
else (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
x = _with_sharding_constraint(x, attn_context_sharding_constraint)
x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)
out = DenseGeneral(features=inputs_q.shape[-1],
transpose_batch_sequence=self.transpose_batch_sequence,
......@@ -967,6 +1007,13 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
num_attention_heads=self.num_attention_heads, dtype=self.dtype,
embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
name='relpos_bias')
enable_rotary_pos_emb: bool, default = False
Whether to apply rotary position embedding to the projected query and key in MHA.
rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
The minimum and maximum time-scales of the rotary position embedding,
only used when :attr:`enable_rotary_pos_emb=True`.
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism for operations other than dot.
Optimization parameters
-----------------------
......@@ -1016,10 +1063,13 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
self_attn_mask_type: str = 'causal'
enable_relative_embedding: bool = True
relative_embedding: nn.Module = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
dtype: DType = jnp.float32
drop_path: float = 0.0
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = False
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
......@@ -1089,6 +1139,16 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
def generate_batch_seqlen_logical_axes(is_sharded_seq=None):
axes = [None, None]
is_sharded_seq = self.enable_sequence_parallel if is_sharded_seq is None \
else is_sharded_seq
axes[batch_dim] = BATCH_AXES
axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
return tuple(axes)
attn_bias = None
if self.enable_relative_embedding:
if self.relative_embedding is None:
......@@ -1120,7 +1180,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
else:
mha_name = 'self_attention'
inputs = _with_sharding_constraint(inputs, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))
inputs = with_sharding_constraint_by_logical_axes(
inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
# [batch, length, emb_dim] -> [batch, length, emb_dim]
x, residual = MultiHeadAttention(
......@@ -1129,6 +1190,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
dropout_rate=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_attention_logits,
......@@ -1140,6 +1202,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm,
attn_mask_type=self.self_attn_mask_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
fuse_qkv=self.fuse_qkv_params,
kernel_init=self.mha_kernel_init,
use_bias=self.use_bias,
......@@ -1161,6 +1225,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
broadcast_dims=self.hidden_dropout_dims,
rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)
x = with_sharding_constraint_by_logical_axes(
x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
residual = with_sharding_constraint_by_logical_axes(
residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
x = hidden_dropout(x, deterministic)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
......@@ -1174,12 +1243,16 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
assert encoded is not None, \
"encoded is required when layer_type == TransformerLayerType.DECODER."
x = with_sharding_constraint_by_logical_axes(
x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
y, residual = MultiHeadAttention(
num_heads=self.num_attention_heads,
dtype=self.dtype,
head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
dropout_rate=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
layernorm_type=self.layernorm_type,
......@@ -1189,6 +1262,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
apply_residual_connection_post_layernorm,
output_layernorm=False, # Must do LayerNorm before MHA.
attn_mask_type='padding',
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
......@@ -1200,10 +1275,17 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
encoded,
encoder_decoder_mask,
deterministic=deterministic)
y = with_sharding_constraint_by_logical_axes(
y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
residual = with_sharding_constraint_by_logical_axes(
residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
y = hidden_dropout(y, deterministic)
mlp_input = y + residual
mlp_input = _with_sharding_constraint(mlp_input, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))
mlp_input = with_sharding_constraint_by_logical_axes(
mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
# MlpBlock
residual = mlp_input
......@@ -1228,6 +1310,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
bias_init=self.bias_init,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
bias_axes_2=(W_NO_SHARD_AXES,),
layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
name='mlp',
)(mlp_input, deterministic=deterministic)
......@@ -1235,6 +1320,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
assert ln_out is not None
residual = ln_out
z = with_sharding_constraint_by_logical_axes(
z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
residual = with_sharding_constraint_by_logical_axes(
residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
z = hidden_dropout(z, deterministic)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
......@@ -1243,6 +1333,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
z = z + residual
if self.output_layernorm:
z = with_sharding_constraint_by_logical_axes(
z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
z = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
......
......@@ -4,6 +4,8 @@
"""JAX layernorm modules"""
from functools import partial
from typing import Tuple
import jax
import jax.numpy as jnp
......@@ -12,6 +14,7 @@ from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl, get_precision_of_fp8_dot
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
def canonicalize_layernorm_type(x):
......@@ -96,14 +99,20 @@ def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)
def layernorm_fp8_dot(x: jnp.ndarray,
kernel: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
fp8_meta_pkg: FP8MetaPackage,
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6) -> jnp.ndarray:
def layernorm_fp8_dot(
x: jnp.ndarray,
kernel: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
fp8_meta_pkg: FP8MetaPackage,
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
layernorm_input_axes: Tuple[str, ...] = None,  # Logical axes of the sharding constraint on the layernorm input.
dot_input_axes: Tuple[str, ...] = None  # Logical axes of the sharding constraint on the dot input.
) -> jnp.ndarray:
"""
Layernorm + FP8 GEMM
"""
......@@ -114,18 +123,21 @@ def layernorm_fp8_dot(x: jnp.ndarray,
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
output = _layernorm_fp8_dot(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
layernorm_type, fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon)
layernorm_type, fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon,
layernorm_input_axes, dot_input_axes)
return output
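A hypothetical call site for the extended signature, mirroring how MultiHeadAttention passes the new axis arguments (x, kernel, gamma, beta, and fp8_meta_pkg are assumed to come from the surrounding module):

from transformer_engine.jax.sharding import (BATCH_AXES, SEQLEN_AXES,
                                             SEQLEN_TP_AXES, HIDDEN_AXES)

out = layernorm_fp8_dot(x, kernel, gamma, beta, fp8_meta_pkg,
                        layernorm_type='layernorm',
                        # The layernorm input may stay sequence-sharded under SP...
                        layernorm_input_axes=(BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES),
                        # ...while the FP8 GEMM input is gathered along the sequence.
                        dot_input_axes=(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))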
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12))
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14))
def _layernorm_fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
fp8_max: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray, layernorm_type: str, fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype, zero_centered_gamma: bool, epsilon: float):
bwd_dtype: jnp.dtype, zero_centered_gamma: bool, epsilon: float,
layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...]):
output, _ = _layernorm_fp8_dot_fwd_rule(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
layernorm_type, fwd_dtype, bwd_dtype,
zero_centered_gamma, epsilon)
zero_centered_gamma, epsilon, layernorm_input_axes,
dot_input_axes)
return output
......@@ -142,7 +154,9 @@ def _layernorm_fp8_dot_fwd_rule(
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
zero_centered_gamma,
epsilon):
epsilon,
layernorm_input_axes,
dot_input_axes):
x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,)
......@@ -156,6 +170,8 @@ def _layernorm_fp8_dot_fwd_rule(
x_scale = scale[gemm_x_idx]
x_scale_inv = scale_inv[gemm_x_idx]
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
if layernorm_type == 'layernorm':
ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
x,
......@@ -191,6 +207,8 @@ def _layernorm_fp8_dot_fwd_rule(
casted_kernel, updated_kernel_amax = \
cast_fp8(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype)
ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_input_axes)
# (batch..., hidden_in) x (hidden_in, hidden_out...)
output = fp8_dot_impl(ln_out, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype,
(x_contracting_dims, k_contracting_dims),
......@@ -209,6 +227,8 @@ def _layernorm_fp8_dot_bwd_rule(
bwd_dtype,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument
ctx,
grad):
ln_out_, casted_kernel, fp8_max, amax, scale, scale_inv, \
......@@ -243,6 +263,7 @@ def _layernorm_fp8_dot_bwd_rule(
(g_for_dgrad_constracting_dim, k_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
if layernorm_type == 'layernorm':
dx, dgamma, dbeta = layernorm_bwd(dgrad,
x,
......
......@@ -3,13 +3,16 @@
# See LICENSE for license information.
"""JAX MLP modules"""
from typing import List
from typing import List, Tuple
from functools import partial
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from .cpp_extensions import cast_fp8, transpose, cast_transpose
from .cpp_extensions import gelu as te_gelu
from .cpp_extensions import gelu_fp8, dgelu, dgelu_dbias_cast_transpose
from .cpp_extensions import gated_gelu, gated_gelu_fp8
from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
......@@ -17,6 +20,40 @@ from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
def gelu(x: jnp.ndarray):
"""
GeLU
"""
output = _gelu(x)
return output
@jax.custom_vjp
def _gelu(x: jnp.ndarray):
gelu_output, _ = _gelu_fwd_rule(x)
return gelu_output
def _gelu_fwd_rule(x):
gelu_output = te_gelu(x)
return gelu_output, (x,)
def _gelu_bwd_rule(ctx, g):
x, = ctx
assert x.dtype == g.dtype
dx = dgelu(g, x)
dx = jnp.reshape(dx, x.shape)
return (dx,)
_gelu.defvjp(_gelu_fwd_rule, _gelu_bwd_rule)
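A quick gradient sanity check for the custom-VJP wiring above; this assumes the fused te_gelu/dgelu kernels match JAX's tanh-approximate GeLU, which is stated here as an assumption for illustration:

import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 8, dtype=jnp.float32)
ref_grad = jax.grad(lambda v: jax.nn.gelu(v, approximate=True).sum())(x)
fused_grad = jax.grad(lambda v: gelu(v).sum())(x)  # routes through _gelu_bwd_rule
# Expect jnp.allclose(ref_grad, fused_grad, atol=1e-3) to hold.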
def geglu(x: jnp.ndarray):
......@@ -47,9 +84,9 @@ def _geglu_bwd_rule(ctx, g):
x, = ctx
assert x.dtype == g.dtype
dgelu = dgated_gelu(g, x)
dgelu = jnp.reshape(dgelu, x.shape)
return (dgelu,)
dx = dgated_gelu(g, x)
dx = jnp.reshape(dx, x.shape)
return (dx,)
_geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule)
......@@ -62,7 +99,12 @@ def layernorm_geglu_fp8_mlp(x: jnp.ndarray,
fp8_gemm_pkg: FP8MetaPackage,
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6) -> jnp.ndarray:
epsilon: float = 1e-6,
layernorm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = 'ffn1',
ffn2_ckpt_name: str = 'ffn2') -> jnp.ndarray:
"""
Layernorm + GEMM1 + GeGLU + GEMM2
"""
......@@ -88,19 +130,26 @@ def layernorm_geglu_fp8_mlp(x: jnp.ndarray,
output = _layernorm_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale,
scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
zero_centered_gamma, epsilon)
zero_centered_gamma, epsilon, layernorm_input_axes,
dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name,
ffn2_ckpt_name)
return output
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13))
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _layernorm_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray,
amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str,
zero_centered_gamma: bool, epsilon: float):
zero_centered_gamma: bool, epsilon: float,
layernorm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
ffn1_ckpt_name: str, ffn2_ckpt_name: str):
output, _ = _layernorm_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax,
scale, scale_inv, fwd_dtype, bwd_dtype,
layernorm_type, zero_centered_gamma, epsilon)
layernorm_type, zero_centered_gamma, epsilon,
layernorm_input_axes, dot_1_input_axes,
dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name)
return output
......@@ -118,7 +167,12 @@ def _layernorm_geglu_fp8_mlp_fwd_rule(
bwd_dtype, # pylint: disable=unused-argument
layernorm_type,
zero_centered_gamma,
epsilon):
epsilon,
layernorm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name,
ffn2_ckpt_name):
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (Hidden_in, 2, Hidden_out)
......@@ -141,6 +195,8 @@ def _layernorm_geglu_fp8_mlp_fwd_rule(
x_scale = scale[gemm1_x_idx]
x_scale_inv = scale_inv[gemm1_x_idx]
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
if layernorm_type == 'layernorm':
ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
x,
......@@ -175,10 +231,13 @@ def _layernorm_geglu_fp8_mlp_fwd_rule(
casted_kernel_1, updated_kernel_1_amax = \
cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)
ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)
# (batch..., hidden_in) x (hidden_in, 2, hidden_out)
dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
(x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
......@@ -191,6 +250,8 @@ def _layernorm_geglu_fp8_mlp_fwd_rule(
geglu_out_scale, geglu_out_scale_inv,
fwd_dtype)
casted_geglu_out = with_sharding_constraint_by_logical_axes(casted_geglu_out, dot_2_input_axes)
kernel_2_scale = scale[gemm2_kernel_idx]
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
# Note (Ming Huang): Use native cast to allow XLA handle tranpose for avoiding
......@@ -201,6 +262,7 @@ def _layernorm_geglu_fp8_mlp_fwd_rule(
dot_2_output = fp8_dot_impl(casted_geglu_out, casted_kernel_2, geglu_out_scale_inv,
kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kernel_1,
casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax,
......@@ -215,6 +277,11 @@ def _layernorm_geglu_fp8_mlp_bwd_rule(
layernorm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name, # pylint: disable=unused-argument
ffn2_ckpt_name, # pylint: disable=unused-argument
ctx,
grad):
x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, \
......@@ -228,6 +295,9 @@ def _layernorm_geglu_fp8_mlp_bwd_rule(
grad_scale = scale[gemm2_grad_idx]
grad_scale_inv = scale_inv[gemm2_grad_idx]
# The output (and thus its gradient) shares the sharding of dot_1's input.
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
casted_grad, casted_grad_t, updated_grad_amax = \
cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
static_axis_boundary=-1, transpose_axis_boundary=-1)
......@@ -248,6 +318,8 @@ def _layernorm_geglu_fp8_mlp_bwd_rule(
grad.dtype, (x_contracting_dims, (1,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)
dgeglu_amax = amax[gemm1_grad_idx, 0:1]
......@@ -280,6 +352,8 @@ def _layernorm_geglu_fp8_mlp_bwd_rule(
grad.dtype, (x_contracting_dims_plus_act_dim, (1, 2)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)
if layernorm_type == 'layernorm':
dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
x,
......@@ -309,3 +383,315 @@ def _layernorm_geglu_fp8_mlp_bwd_rule(
_layernorm_geglu_fp8_mlp.defvjp(_layernorm_geglu_fp8_mlp_fwd_rule,
_layernorm_geglu_fp8_mlp_bwd_rule)
def layernorm_gelu_fp8_mlp(x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
kernels: List[jnp.ndarray],
biases: List[jnp.ndarray],
fp8_gemm_pkg: FP8MetaPackage,
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
layernorm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = 'ffn1',
ffn2_ckpt_name: str = 'ffn2') -> jnp.ndarray:
"""
Layernorm + GEMM1 + bias + GeLU + GEMM2 + bias
"""
assert len(kernels) == 2
assert fp8_gemm_pkg.num_of_gemm == len(kernels)
kernel_1 = kernels[0]
kernel_2 = kernels[1]
bias_1 = biases[0]
bias_2 = biases[1]
fp8_max = fp8_gemm_pkg.fp8_max
amax = fp8_gemm_pkg.amax
scale = fp8_gemm_pkg.scale
scale_inv = fp8_gemm_pkg.scale_inv
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm':
assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
output = _layernorm_gelu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max,
amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
zero_centered_gamma, epsilon, layernorm_input_axes,
dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name,
ffn2_ckpt_name)
return output
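A hypothetical invocation with sequence parallelism enabled, matching the axis tuples TransformerLayer passes to its MLP block (shapes follow the kernel layout asserted in the forward rule below):

from transformer_engine.jax.sharding import (BATCH_AXES, SEQLEN_AXES,
                                             SEQLEN_TP_AXES, HIDDEN_AXES,
                                             HIDDEN_TP_AXES)

# kernel_1: (hidden_in, 1, hidden_ffn); kernel_2: (hidden_ffn, hidden_out)
out = layernorm_gelu_fp8_mlp(x, gamma, None,  # beta=None for 'rmsnorm'
                             [kernel_1, kernel_2], [bias_1, bias_2],
                             fp8_gemm_pkg, 'rmsnorm',
                             layernorm_input_axes=(BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES),
                             dot_1_input_axes=(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES),
                             dot_2_input_axes=(BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES))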
@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20))
def _layernorm_gelu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray,
bias_2: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool,
epsilon: float, layernorm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
ffn1_ckpt_name: str, ffn2_ckpt_name: str):
output, _ = _layernorm_gelu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2,
fp8_max, amax, scale, scale_inv, fwd_dtype,
bwd_dtype, layernorm_type, zero_centered_gamma,
epsilon, layernorm_input_axes, dot_1_input_axes,
dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name)
return output
def _layernorm_gelu_fp8_mlp_fwd_rule(
x,
gamma,
beta,
kernel_1,
kernel_2,
bias_1,
bias_2,
fp8_max,
amax,
scale,
scale_inv,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
layernorm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name,
ffn2_ckpt_name):
# x should have shape (batch..., hidden)
# Kernel_1 should have shape (Hidden_in, 1, Hidden_out)
# Kernel_2 should have shape (Hidden_in, Hidden_out)
assert len(kernel_1.shape) == 3
assert kernel_1.shape[-2] == 1
assert len(kernel_2.shape) == 2
x_contracting_dims = (len(x.shape) - 1,)
xt_batch_dims = tuple(range(1, x.ndim))
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
assert kernel_1.shape[-1] == kernel_2.shape[0]
# Squeeze act axis
# (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out)
kernel_1 = jnp.squeeze(kernel_1, axis=-2)
amax = FP8Helper.update_amax_history(amax)
gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
x_amax = amax[gemm1_x_idx, 0:1]
x_scale = scale[gemm1_x_idx]
x_scale_inv = scale_inv[gemm1_x_idx]
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
if layernorm_type == 'layernorm':
ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
x,
gamma,
beta,
x_amax,
x_scale,
x_scale_inv,
out_dtype=fwd_dtype,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
gamma,
x_amax,
x_scale,
x_scale_inv,
out_dtype=fwd_dtype,
epsilon=epsilon)
mu = None
assert x.shape == ln_out.shape
kernel_1_amax = amax[gemm1_kernel_idx, 0:1]
kernel_1_scale = scale[gemm1_kernel_idx]
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
# Note (Ming Huang): Use cast only, letting XLA handle the transpose, to avoid
# an unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel_1, updated_kernel_1_amax = \
cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)
ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)
# (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
(x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape
dot_1_output += jnp.reshape(bias_1, bias_1_shape)
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
gelu_out_amax = amax[gemm2_x_idx, 0:1]
gelu_out_scale = scale[gemm2_x_idx]
gelu_out_scale_inv = scale_inv[gemm2_x_idx]
# GeLU keeps the shape: (batch..., hidden) -> (batch..., hidden)
casted_gelu_out, updated_gelu_amax = gelu_fp8(dot_1_output, gelu_out_amax, gelu_out_scale,
gelu_out_scale_inv, fwd_dtype)
casted_gelu_out = with_sharding_constraint_by_logical_axes(casted_gelu_out, dot_2_input_axes)
kernel_2_scale = scale[gemm2_kernel_idx]
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
# Note (Ming Huang): Use native cast, letting XLA handle the transpose, to avoid
# an unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)
# (batch..., hidden) x (hidden, hidden_out)
dot_2_output = fp8_dot_impl(casted_gelu_out, casted_kernel_2, gelu_out_scale_inv,
kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape
dot_2_output += jnp.reshape(bias_2, bias_2_shape)
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, casted_kernel_1,
casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_gelu_amax,
updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims,
bias_1.shape, bias_2.shape)
return dot_2_output, ctx
def _layernorm_gelu_fp8_mlp_bwd_rule(
fwd_dtype, # pylint: disable=unused-argument
bwd_dtype,
layernorm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name, # pylint: disable=unused-argument
ffn2_ckpt_name, # pylint: disable=unused-argument
ctx,
grad):
x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, \
casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
updated_gelu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape = ctx
gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)
grad_amax = amax[gemm2_grad_idx, 0:1]
grad_scale = scale[gemm2_grad_idx]
grad_scale_inv = scale_inv[gemm2_grad_idx]
# The output (and thus its gradient) shares the sharding of dot_1's input.
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
casted_grad, casted_grad_t, updated_grad_amax = \
cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
static_axis_boundary=-1, transpose_axis_boundary=-1)
casted_gelu_out_t = transpose(casted_gelu_out,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
dbias_2 = jnp.sum(grad, axis=tuple(range(grad.ndim - 1)))
dbias_2 = jnp.reshape(dbias_2, bias_2_shape)
# (hidden, batch...,) x (hidden, batch...)
gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
wgrad_2 = fp8_dot_impl(casted_gelu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv,
grad.dtype, (xt_batch_dims, xt_batch_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
# (batch..., hidden_out) x (hidden_in, hidden_out)
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
grad.dtype, (x_contracting_dims, (1,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)
dgelu_amax = amax[gemm1_grad_idx, 0:1]
dgelu_scale = scale[gemm1_grad_idx]
dgelu_scale_inv = scale_inv[gemm1_grad_idx]
casted_dgelu, casted_dgelu_t, dbias_1, updated_dgelu_amax = dgelu_dbias_cast_transpose(
dgrad_2,
dot_1_output,
dgelu_amax,
dgelu_scale,
dgelu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)
# (hidden, batch...) x (hidden, batch...)
gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgelu_t, gemm1_x_scale_inv, dgelu_scale_inv, grad.dtype,
(xt_batch_dims, xt_batch_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
# Expand act axis to match the shape with the given kernel_1
wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2)
# (batch..., hidden_out) x (hidden_in, hidden_out)
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
dgrad_1 = fp8_dot_impl(casted_dgelu, casted_kernel_1, dgelu_scale_inv, kernel_1_scale_inv,
grad.dtype, (x_contracting_dims, (1,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)
if layernorm_type == 'layernorm':
dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
x,
mu,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
amax = amax.at[gemm1_grad_idx, 0].set(updated_dgelu_amax[0])
amax = amax.at[gemm2_x_idx, 0].set(updated_gelu_amax[0])
amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \
fp8_max, amax, scale, scale_inv
_layernorm_gelu_fp8_mlp.defvjp(_layernorm_gelu_fp8_mlp_fwd_rule, _layernorm_gelu_fp8_mlp_bwd_rule)
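The ffn1/ffn2 tags attached via checkpoint_name in the forward rule exist so that remat policies under nn.scan can save exactly these activations. A sketch, with layer_fn standing in as a hypothetical per-layer function:

import jax

policy = jax.checkpoint_policies.save_only_these_names('ffn1', 'ffn2')
remat_layer_fn = jax.checkpoint(layer_fn, policy=policy)
# Only the tagged FFN outputs are saved for the backward pass; all other
# intermediates are rematerialized, trading compute for activation memory.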
......@@ -77,6 +77,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
attn_mask_type: str = 'causal'
fuse_qkv: bool = True
transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
float32_logits: bool = False
......@@ -109,6 +110,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
attn_mask_type=self.attn_mask_type,
fuse_qkv=self.fuse_qkv,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
float32_logits=self.float32_logits)
......@@ -156,11 +158,14 @@ class TransformerLayer(TransformerEngineBaseLayer):
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
self_attn_mask_type: str = 'causal'
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
enable_relative_embedding: bool = True
relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
drop_path: float = 0.0
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = False
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
......@@ -221,11 +226,14 @@ class TransformerLayer(TransformerEngineBaseLayer):
float32_attention_logits=self.float32_attention_logits,
layer_type=self.layer_type,
self_attn_mask_type=self.self_attn_mask_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
enable_relative_embedding=self.enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
drop_path=self.drop_path,
fuse_qkv_params=self.fuse_qkv_params,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init)
......
......@@ -4,7 +4,7 @@
"""
Sharding Meta for xmap with CustomCall
"""
import os
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
......@@ -16,6 +16,19 @@ from jax.sharding import PartitionSpec
_PXLA_THREAD_RESOURCES = pxla.thread_resources
# Axis Names
BATCH_AXES = 'nvte_batch'
SEQLEN_AXES = 'nvte_seqlen'
SEQLEN_TP_AXES = 'nvte_seqlen_tp'
HEAD_AXES = 'nvte_head'
HIDDEN_AXES = 'nvte_hidden'
HIDDEN_TP_AXES = 'nvte_hidden_tp'
JOINED_AXES = 'nvte_joined'
W_NO_SHARD_AXES = 'nvte_w_no_shard'
W_FSDP_AXES = 'nvte_w_fsdp'
W_TP_AXES = 'nvte_w_tp'
W_JOINED_AXES = 'nvte_w_joined'
def _get_mesh_info(resource: str):
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
......@@ -24,6 +37,81 @@ def _get_mesh_info(resource: str):
return mesh.shape[resource], resource
def get_sharding_map_logic_axis_to_mesh_axis():
"""
Generate a dict to map logical axes to mesh axes.
"""
gsr = global_mesh_resource()
IS_FSDP_OUTER = bool(int(os.environ.get("NVTE_OUTER_BATCH_FSDP_DIM", "0")))
batch_resources = [gsr.fsdp_resource, gsr.dp_resource] if IS_FSDP_OUTER \
else [gsr.dp_resource, gsr.fsdp_resource]
batch_dim_rule = []
for resource in batch_resources:
if resource is not None and resource not in batch_dim_rule:
batch_dim_rule.append(resource)
if len(batch_dim_rule) <= 0:
batch_dim_rule = None
elif len(batch_dim_rule) == 1:
batch_dim_rule = batch_dim_rule[0]
else:
batch_dim_rule = tuple(batch_dim_rule)
te_logical_axis_to_mesh_axis = {
BATCH_AXES: batch_dim_rule,
SEQLEN_AXES: None,
SEQLEN_TP_AXES: gsr.tp_resource,
HEAD_AXES: gsr.tp_resource,
HIDDEN_AXES: None,
HIDDEN_TP_AXES: gsr.tp_resource,
JOINED_AXES: None,
W_NO_SHARD_AXES: None,
W_FSDP_AXES: gsr.fsdp_resource,
W_TP_AXES: gsr.tp_resource,
W_JOINED_AXES: None,
}
return te_logical_axis_to_mesh_axis
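As a concrete illustration: assuming a mesh with axes ('data', 'model') and MeshResource(dp_resource='data', tp_resource='model'), the mapping resolves to roughly:

# BATCH_AXES     -> 'data'
# SEQLEN_AXES    -> None   (replicated)
# SEQLEN_TP_AXES -> 'model'
# HEAD_AXES      -> 'model'
# HIDDEN_AXES    -> None
# HIDDEN_TP_AXES -> 'model'
# With fsdp_resource also set, it joins BATCH_AXES as ('data', 'fsdp'), or as
# ('fsdp', 'data') when NVTE_OUTER_BATCH_FSDP_DIM=1.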
def generate_pspec(logical_axis_names):
"""
Convert logical axes to PartitionSpec
"""
rules = get_sharding_map_logic_axis_to_mesh_axis()
mesh_axis_names = [rules[name] for name in logical_axis_names]
pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
return pspec
def with_sharding_constraint(x: jnp.ndarray, pspec: PartitionSpec):
"""
A wrapper function to jax.lax.with_sharding_constraint to
support the case that Mesh is empty.
"""
if pspec is None:
return x
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh.empty:
return x
return jax.lax.with_sharding_constraint(x, pspec)
def with_sharding_constraint_by_logical_axes(x: jnp.ndarray, logical_axis_names: tuple | list):
"""
A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
"""
if logical_axis_names is None:
return x
assert len(x.shape) == len(logical_axis_names)
pspec = generate_pspec(logical_axis_names)
return with_sharding_constraint(x, pspec)
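A minimal usage sketch: inside a Mesh context the call constrains the layout, while on an empty mesh (single device) it degrades to a no-op, so module code stays portable:

from transformer_engine.jax.sharding import (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES,
                                             with_sharding_constraint_by_logical_axes)

# hidden: a rank-3 activation; the tuple length must match the array rank.
hidden = with_sharding_constraint_by_logical_axes(
    hidden, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))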
def get_all_mesh_axes():
"""
Get all name of mesh axes
......@@ -42,17 +130,6 @@ def get_padded_spec(spec, ndim):
return spec + (None,) * (ndim - len(spec))
def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
"""
A wrapper function to jax.lax.with_sharding_constraint to
support the case that Mesh is empty.
"""
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh.empty:
return x
return jax.lax.with_sharding_constraint(x, pspec)
def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str):
"""
A wrapper function to invoke lax.p* operations, like psum.
......