Unverified Commit ce163f9e authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

[JAX] Support SP + RoPE + GeLU (#602)



* Adding support of sequence parallelism
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding RoPE
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fix wrong batch_logical_axes
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Renaming FSDP outer env var
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Porting RoPE to Praxis layers.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Porting GeLU + [FP8 Cast].
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* WAR to make XLA successfully match FP8 GEMM on FFN1 with GeLU.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Allowing arbitrary dimension of NVShape for the workspace allocation
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding checkpoint_name to fused functions of mlp.py to get better perf with nn.scan.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Modify with review feedback.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fix bugs
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fix typo.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fixed for lint
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Follow review feedback to modify code.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fix typo.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Port SP to Praxis
Signed-off-by: default avatarMing-Xu Huang <mingh@nvidia.com>

* Fix an issue when enabling both GQA and RoPE.
Signed-off-by: default avatarMing-Xu Huang <mingh@nvidia.com>

* Update docs
Signed-off-by: default avatarMing-Xu Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
Signed-off-by: default avatarMing-Xu Huang <mingh@nvidia.com>
parent 29b0c9ca
...@@ -6,10 +6,23 @@ ...@@ -6,10 +6,23 @@
Jax Jax
======= =======
.. autoapiclass:: transformer_engine.jax.MajorShardingType Pre-defined Variable of Logical Axes
.. autoapiclass:: transformer_engine.jax.ShardingType ------------------------------------
Variables are available in `transformer_engine.jax.sharding`.
* BATCH_AXES: The logical axis of batch dimension. It is usually sharded along DP + FSDP on Mesh.
* SEQLEN_AXES: The logical axis of sequence length dimension. It is usually not sharded.
* SEQLEN_TP_AXES: The logical axis of sequence length dimension. It is usually sharded along TP on Mesh.
* HEAD_AXES: The logical axis of head dimension of MHA. It is usually sharded along TP on Mesh.
* HIDDEN_AXES: The logical axis of hidden dimension. It is usually not sharded.
* HIDDEN_TP_AXES: The logical axis of hidden dimension. It is usually sharded along TP on Mesh.
* JOINED_AXES: The logical axis of non-defined dimension. It is usually not sharded.
Modules
------------------------------------
.. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType .. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType
.. autoapiclass:: transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None) .. autoapiclass:: transformer_engine.jax.MeshResource()
.. autoapifunction:: transformer_engine.jax.fp8_autocast .. autoapifunction:: transformer_engine.jax.fp8_autocast
......
...@@ -35,6 +35,7 @@ INPUT_KEY = 'input_rng' ...@@ -35,6 +35,7 @@ INPUT_KEY = 'input_rng'
class Net(nn.Module): class Net(nn.Module):
"""NLP Encoder""" """NLP Encoder"""
num_embed: int num_embed: int
enable_seq_paral: bool
@nn.compact @nn.compact
def __call__(self, x, mask, disable_dropout=False): def __call__(self, x, mask, disable_dropout=False):
...@@ -50,11 +51,17 @@ class Net(nn.Module): ...@@ -50,11 +51,17 @@ class Net(nn.Module):
layer_type=te_flax.TransformerLayerType.ENCODER, layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type='padding', self_attn_mask_type='padding',
enable_relative_embedding=False, enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
dtype=jnp.bfloat16) dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1) x = x.reshape(x.shape[0], -1)
if self.enable_seq_paral:
# Trigger all-gather to collect a complete tensor along the sequence dim on each device.
x = jax.lax.with_sharding_constraint(x,
jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None))
x = te_flax.DenseGeneral(features=256, x = te_flax.DenseGeneral(features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,), bias_axes=(NAMED_TP_AXIS,),
...@@ -266,7 +273,7 @@ def train_and_evaluate(args): ...@@ -266,7 +273,7 @@ def train_and_evaluate(args):
with te.fp8_autocast(args.use_fp8, with te.fp8_autocast(args.use_fp8,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None,
None)): None)):
encoder = Net(num_embed) encoder = Net(num_embed, args.enable_sp)
inputs = jnp.zeros(input_shape, dtype=jnp.int32) inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8) masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
...@@ -379,6 +386,10 @@ def encoder_parser(args): ...@@ -379,6 +386,10 @@ def encoder_parser(args):
action="store_true", action="store_true",
default=False, default=False,
help="Use FP8 for inference and training without recalibration") help="Use FP8 for inference and training without recalibration")
parser.add_argument("--enable-sp",
action="store_true",
default=False,
help="Enable sequence parallelism.")
return parser.parse_args(args) return parser.parse_args(args)
...@@ -405,6 +416,20 @@ class TestEncoder(unittest.TestCase): ...@@ -405,6 +416,20 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79 assert actual[0] < 0.45 and actual[1] > 0.79
def test_te_bf16_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8_sp(self):
"""Test Transformer Engine with FP8 + SP"""
self.args.enable_sp = True
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
if __name__ == "__main__": if __name__ == "__main__":
train_and_evaluate(encoder_parser(None)) train_and_evaluate(encoder_parser(None))
...@@ -14,6 +14,8 @@ from jax import jit, value_and_grad ...@@ -14,6 +14,8 @@ from jax import jit, value_and_grad
from flax import linen as nn from flax import linen as nn
from utils import assert_allclose from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import dgelu, dgelu_dbias_cast_transpose
from transformer_engine.jax.cpp_extensions import gelu, gelu_fp8
from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu
from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8 from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
...@@ -21,6 +23,7 @@ from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper ...@@ -21,6 +23,7 @@ from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp
from transformer_engine.jax.mlp import layernorm_gelu_fp8_mlp
GEMM_CASES = [ GEMM_CASES = [
(256, 256, 512), (256, 256, 512),
...@@ -285,6 +288,126 @@ class TestFP8Dot: ...@@ -285,6 +288,126 @@ class TestFP8Dot:
jnp.asarray(ref_s_grad, np.float32), jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE) dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
                                   (16384, 1024, 1024)])
def test_grad_ln_gelu_fp8_mlp(self, m, n, k):
    """Check layernorm_gelu_fp8_mlp value and gradients against a pure-JAX reference.

    Both the fused primitive and the unfused reference are run three times so
    the FP8 amax/scale metadata reaches a steady state before comparison.
    """
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 6)
    activations = ('gelu',)

    # Random bf16 operands: input a, kernels k1/k2, biases b1/b2, RMSNorm scale s.
    a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
    k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16)
    k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
    b1 = jax.random.normal(subkeys[3], (len(activations), n), jnp.bfloat16)
    b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
    s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)

    # FP8 metadata sized for two GEMMs (hence NUM_META_PER_GEMM * 2 rows).
    init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
    init_fp8_metas_amax = jnp.zeros(
        (FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
    init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
    init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)

    def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
                       fp8_metas_scale_inv):
        # x: 2D input; y, z: kernels; w, v: biases.
        # Fused path: out = gelu(rmsnorm(x) @ y + w) @ z + v, computed in FP8.
        fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                      fp8_metas_scale_inv)
        return jnp.mean(
            layernorm_gelu_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm"))

    def ln_gelu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
                            kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
                            fp8_maxs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                            scale_inv: jnp.ndarray) -> jnp.ndarray:
        # Unfused reference: RMSNorm in fp32, then two FP8 GEMMs with GeLU between.
        x = jnp.asarray(x, jnp.float32)
        mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
        y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
        ln_out = y * ln_scale
        ln_out = jnp.asarray(ln_out, jnp.bfloat16)

        # First GEMM consumes the first NUM_META_PER_GEMM rows of each meta tensor.
        fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
                                        amax[:FP8Helper.NUM_META_PER_GEMM],
                                        scale[:FP8Helper.NUM_META_PER_GEMM],
                                        scale_inv[:FP8Helper.NUM_META_PER_GEMM])
        linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))

        bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
        linear_1_out += jnp.reshape(bias_1, bias_1_shape)

        x = jax.nn.gelu(linear_1_out)
        # Drop the activation axis (length 1 here since activations == ('gelu',)).
        x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)

        # Second GEMM consumes the remaining meta rows.
        fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
                                        amax[FP8Helper.NUM_META_PER_GEMM:],
                                        scale[FP8Helper.NUM_META_PER_GEMM:],
                                        scale_inv[FP8Helper.NUM_META_PER_GEMM:])
        output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))

        bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
        output += jnp.reshape(bias_2, bias_2_shape)
        return output

    def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
                 fp8_metas_scale_inv):
        return jnp.mean(
            ln_gelu_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
                                fp8_metas_scale_inv))

    # Differentiate with respect to every argument, including the FP8 metadata,
    # so the updated meta tensors flow out through the gradient tuple.
    value_n_grad_primitive_func = jit(
        value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
    value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))

    ref_fp8_max = init_fp8_max
    ref_fp8_metas_amax = init_fp8_metas_amax
    ref_fp8_metas_scale = init_fp8_metas_scale
    ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv

    pri_fp8_max = init_fp8_max
    pri_fp8_metas_amax = init_fp8_metas_amax
    pri_fp8_metas_scale = init_fp8_metas_scale
    pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv

    # Run a few iterations so the amax history / scales converge first.
    for _ in range(3):
        ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
                  ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
                  ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
                      a, s, k1, k2, b1, b2, ref_fp8_max, ref_fp8_metas_amax,
                      ref_fp8_metas_scale, ref_fp8_metas_scale_inv)

    for _ in range(3):
        primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
                        primitive_k2_grad, primitive_b1_grad, primitive_b2_grad, pri_fp8_max,
                        pri_fp8_metas_amax, pri_fp8_metas_scale,
                        pri_fp8_metas_scale_inv) = value_n_grad_primitive_func(
                            a, s, k1, k2, b1, b2, pri_fp8_max, pri_fp8_metas_amax,
                            pri_fp8_metas_scale, pri_fp8_metas_scale_inv)

    assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
    assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
                    jnp.asarray(ref_a_grad, np.float32),
                    dtype=FP8Helper.BWD_DTYPE)
    assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
                    jnp.asarray(ref_k1_grad, np.float32),
                    dtype=FP8Helper.BWD_DTYPE)
    assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
                    jnp.asarray(ref_k2_grad, np.float32),
                    dtype=FP8Helper.BWD_DTYPE)
    assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
                    jnp.asarray(ref_s_grad, np.float32),
                    dtype=FP8Helper.BWD_DTYPE)
    assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
                    jnp.asarray(ref_b1_grad, np.float32),
                    dtype=jnp.bfloat16)
    assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
                    jnp.asarray(ref_b2_grad, np.float32),
                    dtype=jnp.bfloat16)
@pytest.fixture(name="random_inputs") @pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape): def random_inputs_fixture(shape):
...@@ -294,6 +417,96 @@ def random_inputs_fixture(shape): ...@@ -294,6 +417,96 @@ def random_inputs_fixture(shape):
return out return out
class TestGeLu:
    """Compare the TE GeLU primitive (gelu / dgelu) against jax.nn.gelu,
    both the forward value and the gradient of its mean."""

    def ref_func(self, inputs):
        # Reference: mean(gelu(x)) and its gradient straight from jax.nn.
        reference = jit(value_and_grad(lambda t: jnp.mean(jax.nn.gelu(t))))
        return reference(inputs)

    def prim_func(self, inputs):
        # Wrap the TE kernels in a custom-VJP function so that value_and_grad
        # routes the backward pass through dgelu.

        @jax.custom_vjp
        def te_gelu(x):
            y, _ = fwd_rule(x)
            return y

        def fwd_rule(x):
            # Save the input as the residual for the backward rule.
            return gelu(x), x

        def bwd_rule(saved_x, cotangent):
            return (dgelu(cotangent, saved_x),)

        te_gelu.defvjp(fwd_rule, bwd_rule)
        wrapped = value_and_grad(lambda t: jnp.mean(te_gelu(t)))
        return wrapped(inputs)

    @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
    def test_gelu(self, random_inputs):
        data = random_inputs
        te_out, te_grad = self.prim_func(data)
        ref_out, ref_grad = self.ref_func(data)
        assert_allclose(te_out, ref_out, dtype=data.dtype)
        assert_allclose(te_grad, ref_grad, dtype=data.dtype)
class TestGeLuFP8(TestGeLu):
    """FP8 variant: checks gelu_fp8 / dgelu_dbias_cast_transpose against jax.nn.gelu."""

    def prim_func(self, inputs):
        # FP8 scaling metadata is seeded on self by test_gelu before this runs.
        amax = self.amax
        scale = self.scale
        scale_inv = self.scale_inv
        # Dummy operands: padding the VJP to four inputs makes value_and_grad
        # surface (dgelu, dgelu_transpose, dbias, amax_out) as the four grads.
        no_use = jnp.zeros(1, jnp.float32)

        @jax.custom_vjp
        def primitive(x, y, z, w):
            # NOTE(review): this raw body mis-calls primitive_fwd (which takes
            # four args and returns a tuple); under value_and_grad only the
            # registered fwd/bwd rules execute, so this body is never run.
            out = primitive_fwd(x)
            return out

        def primitive_fwd(x, y, z, w):
            # Quantize through FP8 (e4m3) then dequantize back for comparison.
            out, _ = gelu_fp8(x, amax, scale, scale_inv, jnp.float8_e4m3fn)
            out = dequantize(out, x.dtype, scale_inv)
            ctx = x
            return out, ctx

        def primitive_bwd(ctx, g):
            x = ctx
            # Fused backward: dgelu, its transpose (cast to e5m2), bias grad, amax.
            dgelu, dgelu_trans, dbias, amax_out = dgelu_dbias_cast_transpose(
                g, x, amax, scale, scale_inv, jnp.float8_e5m2, -1)
            dgelu = dequantize(dgelu, x.dtype, scale_inv)
            dgelu_trans = dequantize(dgelu_trans, x.dtype, scale_inv)
            return dgelu, dgelu_trans, dbias, amax_out

        primitive.defvjp(primitive_fwd, primitive_bwd)
        func = value_and_grad(lambda x, y, z, w: jnp.mean(primitive(x, y, z, w)), (0, 1, 2, 3))

        return func(inputs, no_use, no_use, no_use)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
    def test_gelu(self, random_inputs):
        # Identity scaling so quantize/dequantize round-trips up to FP8 precision.
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
        self.scale_inv = jnp.ones(1, jnp.float32)

        x = random_inputs
        prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
        ref_out, ref_grad = self.ref_func(x)

        assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
        # dbias reduces over every axis except the last (hidden) one.
        assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1))))
        assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(prim_grad_trans,
                        jnp.transpose(ref_grad, (2, 0, 1)),
                        dtype=FP8Helper.BWD_DTYPE)
class TestGatedGeLu: class TestGatedGeLu:
def ref_func(self, inputs): def ref_func(self, inputs):
......
...@@ -88,6 +88,7 @@ _KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence' ...@@ -88,6 +88,7 @@ _KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits" _KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = 'num_attention_heads' _KEY_OF_NUM_HEADS = 'num_attention_heads'
_KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups' _KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups'
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
BASE_ATTRS = { BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True, _KEY_OF_TRANSPOSE_BS: True,
...@@ -137,7 +138,25 @@ ATTRS = [{ ...@@ -137,7 +138,25 @@ ATTRS = [{
_KEY_OF_FUSE_MLP_WI: True _KEY_OF_FUSE_MLP_WI: True
}, { }, {
_KEY_OF_NUM_HEADS: 8, _KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4 _KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('gelu',)),
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_ENABLE_ROPE: True
}, {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_ENABLE_ROPE: True
}] }]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
......
...@@ -818,12 +818,14 @@ class TransformerLayerAttr: ...@@ -818,12 +818,14 @@ class TransformerLayerAttr:
LYR_TYPE = 'layer_type' LYR_TYPE = 'layer_type'
ZERO_CEN = 'zero_centered_gamma' ZERO_CEN = 'zero_centered_gamma'
TRANSPOSE_BS = 'transpose_batch_sequence' TRANSPOSE_BS = 'transpose_batch_sequence'
ENABLE_ROPE = 'enable_rotary_pos_emb'
ATTRS = [{ ATTRS = [{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: 'layernorm',
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -831,6 +833,7 @@ class TransformerLayerAttr: ...@@ -831,6 +833,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -838,6 +841,7 @@ class TransformerLayerAttr: ...@@ -838,6 +841,7 @@ class TransformerLayerAttr:
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -845,6 +849,7 @@ class TransformerLayerAttr: ...@@ -845,6 +849,7 @@ class TransformerLayerAttr:
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -852,6 +857,7 @@ class TransformerLayerAttr: ...@@ -852,6 +857,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -859,6 +865,7 @@ class TransformerLayerAttr: ...@@ -859,6 +865,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -866,6 +873,7 @@ class TransformerLayerAttr: ...@@ -866,6 +873,7 @@ class TransformerLayerAttr:
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -873,6 +881,7 @@ class TransformerLayerAttr: ...@@ -873,6 +881,7 @@ class TransformerLayerAttr:
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -880,6 +889,7 @@ class TransformerLayerAttr: ...@@ -880,6 +889,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -887,6 +897,7 @@ class TransformerLayerAttr: ...@@ -887,6 +897,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -894,6 +905,7 @@ class TransformerLayerAttr: ...@@ -894,6 +905,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -901,6 +913,7 @@ class TransformerLayerAttr: ...@@ -901,6 +913,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -908,6 +921,7 @@ class TransformerLayerAttr: ...@@ -908,6 +921,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -915,6 +929,7 @@ class TransformerLayerAttr: ...@@ -915,6 +929,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -922,6 +937,7 @@ class TransformerLayerAttr: ...@@ -922,6 +937,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -929,6 +945,7 @@ class TransformerLayerAttr: ...@@ -929,6 +945,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -936,6 +953,7 @@ class TransformerLayerAttr: ...@@ -936,6 +953,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -943,6 +961,7 @@ class TransformerLayerAttr: ...@@ -943,6 +961,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -950,6 +969,7 @@ class TransformerLayerAttr: ...@@ -950,6 +969,7 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: True TRANSPOSE_BS: True
}, { }, {
USE_BIAS: True, USE_BIAS: True,
...@@ -957,6 +977,23 @@ class TransformerLayerAttr: ...@@ -957,6 +977,23 @@ class TransformerLayerAttr:
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: True,
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: True,
TRANSPOSE_BS: False TRANSPOSE_BS: False
}] }]
...@@ -984,6 +1021,7 @@ class TestTransformer(TestLayer): ...@@ -984,6 +1021,7 @@ class TestTransformer(TestLayer):
use_bias = attrs[TransformerLayerAttr.USE_BIAS] use_bias = attrs[TransformerLayerAttr.USE_BIAS]
bias_init = WeightInit.Constant(0.0) bias_init = WeightInit.Constant(0.0)
layer_type = attrs[TransformerLayerAttr.LYR_TYPE] layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
enable_relative_embedding = True enable_relative_embedding = True
relative_embedding = pax_fiddle.Config(RelativePositionBiases, relative_embedding = pax_fiddle.Config(RelativePositionBiases,
num_attention_heads=num_attention_heads) num_attention_heads=num_attention_heads)
...@@ -1019,6 +1057,7 @@ class TestTransformer(TestLayer): ...@@ -1019,6 +1057,7 @@ class TestTransformer(TestLayer):
bias_init=bias_init, bias_init=bias_init,
layer_type=layer_type, layer_type=layer_type,
enable_relative_embedding=enable_relative_embedding, enable_relative_embedding=enable_relative_embedding,
enable_rotary_pos_emb=enable_rotary_pos_emb,
relative_embedding=relative_embedding, relative_embedding=relative_embedding,
drop_path=drop_path, drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence)
...@@ -1040,6 +1079,7 @@ class TestTransformer(TestLayer): ...@@ -1040,6 +1079,7 @@ class TestTransformer(TestLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init( bias_init=TransformerEngineBaseLayer.generate_params_init(
"bias", bias_init), "bias", bias_init),
layer_type=layer_type, layer_type=layer_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
enable_relative_embedding=enable_relative_embedding, enable_relative_embedding=enable_relative_embedding,
relative_embedding=relative_embedding_flax_module, relative_embedding=relative_embedding_flax_module,
drop_path=drop_path, drop_path=drop_path,
......
...@@ -340,6 +340,29 @@ class MlpBlock(nn.Module): ...@@ -340,6 +340,29 @@ class MlpBlock(nn.Module):
return output return output
def apply_rotary_pos_emb(
    inputs: jnp.ndarray,
    position: jnp.ndarray,
    min_timescale: int = 1,
    max_timescale: int = 10000,
):
    """Apply rotary position embeddings (RoPE) to the last axis of `inputs`.

    The trailing (embedding) axis is split into two equal halves which are
    rotated by a position-dependent angle; rotation frequencies are log-spaced
    between 1/min_timescale and 1/max_timescale.

    Args:
        inputs: Array whose last axis is the embedding dimension (must be even);
            leading axes are batch/sequence — layout assumed compatible with
            `position` broadcasting, TODO confirm with callers.
        position: Positions, broadcastable against the leading two axes of
            `inputs` after trailing axes are added.
        min_timescale: Shortest rotation timescale.
        max_timescale: Longest rotation timescale.

    Returns:
        Array with the same shape and dtype as `inputs`, with RoPE applied.

    Raises:
        ValueError: If the embedding dimension is odd (the two rotated halves
            must be the same size).
    """
    embedding_dim = inputs.shape[-1]
    if embedding_dim % 2 != 0:
        raise ValueError(
            f"apply_rotary_pos_emb requires an even embedding dim, got {embedding_dim}")
    half_embedding_dim = embedding_dim // 2

    # Per-frequency timescales: min * (max/min)**(2i/dim) for i in [0, dim/2).
    fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
    timescale = min_timescale * (max_timescale / min_timescale)**fraction

    # Broadcast: timescale across all leading axes, position across trailing axes.
    timescale = jnp.expand_dims(timescale, axis=tuple(range(inputs.ndim - 1)))
    position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
    sinusoid_inp = position / timescale
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)

    # Rotate the (first_half, second_half) pair by the per-position angle.
    first_half, second_half = jnp.split(inputs, 2, axis=-1)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    first_part = first_part.astype(inputs.dtype)
    second_part = second_part.astype(inputs.dtype)
    return jnp.concatenate([first_part, second_part], axis=-1)
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
...@@ -368,6 +391,7 @@ class MultiHeadAttention(nn.Module): ...@@ -368,6 +391,7 @@ class MultiHeadAttention(nn.Module):
float32_logits: bool = False # computes logits in float32 for stability. float32_logits: bool = False # computes logits in float32 for stability.
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
enable_rotary_pos_emb: bool = False
fuse_qkv: bool = True fuse_qkv: bool = True
def __post_init__(self): def __post_init__(self):
...@@ -482,6 +506,15 @@ class MultiHeadAttention(nn.Module): ...@@ -482,6 +506,15 @@ class MultiHeadAttention(nn.Module):
key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv) key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv) value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
if self.enable_rotary_pos_emb:
batch_dim = 1 if self.transpose_batch_sequence else 0
seq_dim = 1 - batch_dim
position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
query = apply_rotary_pos_emb(query, position)
key = apply_rotary_pos_emb(key, position)
query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim)) query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
...@@ -802,6 +835,7 @@ class EncoderLayer(nn.Module): ...@@ -802,6 +835,7 @@ class EncoderLayer(nn.Module):
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
output_layernorm: bool = False output_layernorm: bool = False
drop_path: float = 0.0 drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False fuse_mlp_wi: bool = False
...@@ -854,6 +888,7 @@ class EncoderLayer(nn.Module): ...@@ -854,6 +888,7 @@ class EncoderLayer(nn.Module):
scale_attn_logits=self.scale_attn_logits, scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init, scaled_query_init=self.scaled_query_init,
fuse_qkv=self.fuse_qkv_params, fuse_qkv=self.fuse_qkv_params,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
name='attention')(x, name='attention')(x,
x, x,
encoder_mask, encoder_mask,
...@@ -922,6 +957,7 @@ class DecoderLayer(nn.Module): ...@@ -922,6 +957,7 @@ class DecoderLayer(nn.Module):
layernorm_type: str = 'layernorm' layernorm_type: str = 'layernorm'
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
drop_path: float = 0.0 drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False fuse_mlp_wi: bool = False
...@@ -981,6 +1017,7 @@ class DecoderLayer(nn.Module): ...@@ -981,6 +1017,7 @@ class DecoderLayer(nn.Module):
float32_logits=self.float32_attention_logits, float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits, scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init, scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
fuse_qkv=self.fuse_qkv_params, fuse_qkv=self.fuse_qkv_params,
name='self_attention')(x, name='self_attention')(x,
x, x,
...@@ -1014,6 +1051,7 @@ class DecoderLayer(nn.Module): ...@@ -1014,6 +1051,7 @@ class DecoderLayer(nn.Module):
float32_logits=self.float32_attention_logits, float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits, scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init, scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
fuse_qkv=self.fuse_qkv_params, fuse_qkv=self.fuse_qkv_params,
name='encoder_decoder_attention')(y, name='encoder_decoder_attention')(y,
encoded, encoded,
......
...@@ -1349,12 +1349,6 @@ void cast_transpose_dbias_dgelu(const Tensor &input, ...@@ -1349,12 +1349,6 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
Tensor *dbias, Tensor *dbias,
Tensor *workspace, Tensor *workspace,
cudaStream_t stream) { cudaStream_t stream) {
CheckInputTensor(input, "cast_transpose_dbias_dgelu_input");
CheckInputTensor(gelu_input, "gelu_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2, NVTE_CHECK(transposed_output->data.shape.size() == 2,
...@@ -1396,6 +1390,12 @@ void cast_transpose_dbias_dgelu(const Tensor &input, ...@@ -1396,6 +1390,12 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
return; return;
} }
CheckInputTensor(input, "cast_transpose_dbias_dgelu_input");
CheckInputTensor(gelu_input, "gelu_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) * const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
......
...@@ -313,13 +313,14 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -313,13 +313,14 @@ class LayerNormFwdPrimitive(BasePrimitive):
assert x_aval.size % hidden_size == 0 assert x_aval.size % hidden_size == 0
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch size x_aval.size // hidden_size, # batch size
hidden_size, hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16) jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16)
True, kwargs['zero_centered_gamma'], kwargs['epsilon'] True,
) kwargs['zero_centered_gamma'],
kwargs['epsilon'])
wkspace_aval = out_aval.update(shape=wkspace_info[0], wkspace_aval = out_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1])) dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = out_aval.update(shape=barrier_info[0], barrier_aval = out_aval.update(shape=barrier_info[0],
...@@ -384,14 +385,14 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -384,14 +385,14 @@ class LayerNormFwdPrimitive(BasePrimitive):
hidden_size, hidden_size,
wkspace_aval.size, wkspace_aval.size,
barrier_aval.size, barrier_aval.size,
0, # no dgamma_part in FWD pass 0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass 0, # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype), jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype TEDType.kByte, # dummy dbeta_part te_dtype
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
sm_margin, sm_margin,
...@@ -523,7 +524,7 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -523,7 +524,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
dgamma_part_aval = dgamma_aval.update(shape=dgamma_part_info[0], dgamma_part_aval = dgamma_aval.update(shape=dgamma_part_info[0],
dtype=te_dtype_to_jax_dtype(dgamma_part_info[1])) dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]))
dbeta_part_aval = dbeta_aval.update(shape=dbeta_part_info[0], dbeta_part_aval = dbeta_aval.update(shape=dbeta_part_info[0],
dtype=te_dtype_to_jax_dtype(dbeta_part_info[1])) dtype=te_dtype_to_jax_dtype(dbeta_part_info[1]))
return dx_aval, dgamma_aval, dbeta_aval, wkspace_aval, barrier_aval, \ return dx_aval, dgamma_aval, dbeta_aval, wkspace_aval, barrier_aval, \
dgamma_part_aval, dbeta_part_aval dgamma_part_aval, dbeta_part_aval
...@@ -559,7 +560,6 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -559,7 +560,6 @@ class LayerNormBwdPrimitive(BasePrimitive):
hidden_size = reduce(operator.mul, g_shape) hidden_size = reduce(operator.mul, g_shape)
batch_size = reduce(operator.mul, x_shape) // hidden_size batch_size = reduce(operator.mul, x_shape) // hidden_size
out_types = [ out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out for output in ctx.avals_out
...@@ -706,13 +706,14 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -706,13 +706,14 @@ class RmsNormFwdPrimitive(BasePrimitive):
assert x_aval.size % hidden_size == 0 assert x_aval.size % hidden_size == 0
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch size x_aval.size // hidden_size, # batch size
hidden_size, hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16) jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16)
False, False, kwargs['epsilon'] False,
) False,
kwargs['epsilon'])
wkspace_aval = out_aval.update(shape=wkspace_info[0], wkspace_aval = out_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1])) dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
barrier_aval = out_aval.update(shape=barrier_info[0], barrier_aval = out_aval.update(shape=barrier_info[0],
...@@ -764,14 +765,14 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -764,14 +765,14 @@ class RmsNormFwdPrimitive(BasePrimitive):
hidden_size, hidden_size,
wkspace_aval.size, wkspace_aval.size,
barrier_aval.size, barrier_aval.size,
0, # no dgamma_part in FWD pass 0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass 0, # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype), jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype TEDType.kByte, # dummy dbeta_part te_dtype
False, # RMSNorm doesn't support zero_centered_gamma False, # RMSNorm doesn't support zero_centered_gamma
epsilon, epsilon,
sm_margin, sm_margin,
...@@ -936,13 +937,13 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -936,13 +937,13 @@ class RmsNormBwdPrimitive(BasePrimitive):
wkspace_aval.size, wkspace_aval.size,
barrier_aval.size, barrier_aval.size,
dgamma_part_aval.size, dgamma_part_aval.size,
0, # no dbeta_part for RMSnorm 0, # no dbeta_part for RMSnorm
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype), jax_dtype_to_te_dtype(barrier_aval.dtype),
jax_dtype_to_te_dtype(dgamma_part_aval.dtype), jax_dtype_to_te_dtype(dgamma_part_aval.dtype),
TEDType.kByte, # dummy dbeta_part te_dtype TEDType.kByte, # dummy dbeta_part te_dtype
False, # RMSNorm doesn't support zero_centered_gamma False, # RMSNorm doesn't support zero_centered_gamma
epsilon, epsilon,
sm_margin, sm_margin,
...@@ -1906,10 +1907,8 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive): ...@@ -1906,10 +1907,8 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
# prepare for the active fused-attn backend # prepare for the active fused-attn backend
batch_size = reduce(operator.mul, batch_shape) batch_size = reduce(operator.mul, batch_shape)
wkspace_info = transformer_engine_jax.get_self_fused_attn_fwd_workspace_sizes( wkspace_info = transformer_engine_jax.get_self_fused_attn_fwd_workspace_sizes(
batch_size, max_seqlen, num_heads, head_dim, batch_size, max_seqlen, num_heads, head_dim, scaling_factor, dropout_probability,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training
)
wkspace_aval = qkv_aval.update(shape=wkspace_info[0], wkspace_aval = qkv_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1])) dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
...@@ -2271,8 +2270,8 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2271,8 +2270,8 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
# backend determines the softmax buffer shape/dtype # backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD, backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
attn_bias_type, attn_mask_type, dropout_probability, attn_bias_type, attn_mask_type, dropout_probability, num_heads,
num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
q_head_dim).get_fused_attn_backend() q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, kv_max_seqlen) softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, kv_max_seqlen)
...@@ -2298,8 +2297,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2298,8 +2297,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
wkspace_info = transformer_engine_jax.get_cross_fused_attn_fwd_workspace_sizes( wkspace_info = transformer_engine_jax.get_cross_fused_attn_fwd_workspace_sizes(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim, batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training jax_dtype_to_te_dtype(q_aval.dtype), is_training)
)
wkspace_aval = q_aval.update(shape=wkspace_info[0], wkspace_aval = q_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1])) dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
...@@ -2336,9 +2334,8 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2336,9 +2334,8 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
wkspace_aval = ctx.avals_out[-1] wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch_size, q_max_seqlen, kv_max_seqlen, batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim,
num_heads, num_gqa_groups, head_dim, wkspace_aval.size, wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training) is_training)
...@@ -2532,9 +2529,8 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive): ...@@ -2532,9 +2529,8 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
wkspace_aval = ctx.avals_out[-1] wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch_size, q_max_seqlen, kv_max_seqlen, batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim,
num_heads, num_gqa_groups, head_dim, wkspace_aval.size, wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training) is_training)
...@@ -2666,6 +2662,222 @@ def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, ...@@ -2666,6 +2662,222 @@ def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
is_training=is_training) is_training=is_training)
class GeluPrimitive(BasePrimitive):
    """
    GeLU forward primitive (plain, non-gated GeLU), backed by the
    TransformerEngine "te_gelu" custom call.
    """
    name = "te_gelu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()

    @staticmethod
    def abstract(x_aval):
        """
        te_gelu abstract: the output has the same shape and dtype as the input.
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        # The custom call only supports these floating point input types.
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_aval = core.raise_to_shaped(x_aval)
        return out_aval

    @staticmethod
    def lowering(ctx, x):
        """
        te_gelu lowering rules: emit the custom call with a same-shaped output.
        """
        (x_aval,) = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        out_shape = ir_x_shape

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
        ]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # The kernel treats the input as a 2D (batch, hidden) matrix: all
        # leading dimensions are flattened into the batch dimension.
        hidden_size = ir_x_shape[-1]
        batch_size = reduce(operator.mul, ir_x_shape[:-1])
        in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype,
                                                               in_dtype)

        out = custom_caller(GeluPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(x):
        # Forward to the inner (unbatched) primitive registered at init time.
        assert GeluPrimitive.inner_primitive is not None
        out = GeluPrimitive.inner_primitive.bind(x)
        return out

    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        te_gelu batcher for vmap: elementwise op, so the batch dim passes through.
        """
        _check_valid_batch_dims(batch_dims)
        assert GeluPrimitive.outer_primitive is not None
        inputs, = batched_args
        inputs_bdim, = batch_dims

        out_bdims = inputs_bdim
        return GeluPrimitive.outer_primitive.bind(inputs), out_bdims

    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        te_gelu infer_sharding_from_operands: output sharding follows the input.
        """
        del result_infos    # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        return out_sharding

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        te_gelu partitioning: pointwise, so shard exactly like the input.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        impl = GeluPrimitive.impl
        return mesh, impl, out_sharding, arg_shardings


register_primitive(GeluPrimitive)
def gelu(inputs: jnp.ndarray) -> jnp.ndarray:
    """
    Plain (non-gated) GeLU activation wrapper.

    Returns gelu(inputs) computed by the TE custom call; the memory
    layout of `inputs` is assumed to be (N..., H).
    """
    primitive = GeluPrimitive.outer_primitive
    return primitive.bind(inputs)
class DGeluPrimitive(BasePrimitive):
    """
    Backward (gradient) primitive of the plain GeLU activation, backed by
    the TransformerEngine "te_dgelu" custom call.
    """
    name = "te_dgelu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()

    @staticmethod
    def abstract(dz_aval, x_aval):
        """
        te_dgelu abstract: dx has the same shape and dtype as the GeLU input x.
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert dz_aval.shape == x_aval.shape
        out_aval = core.raise_to_shaped(x_aval)
        return out_aval

    @staticmethod
    def lowering(ctx, dz, x):
        """
        te_dgelu lowering rules: dz is the incoming gradient, x the original
        GeLU input.
        """
        in_aval, gi_aval = ctx.avals_in
        assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gi_aval.dtype == in_aval.dtype
        ir_in_type = ir.RankedTensorType(dz.type)
        ir_in_shape = ir_in_type.shape
        gi_type = ir.RankedTensorType(x.type)
        gi_shape = gi_type.shape
        # The incoming gradient and the saved GeLU input must match exactly.
        assert ir_in_shape == gi_shape
        # The kernel views tensors as 2D (batch, hidden): flatten leading dims.
        ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
        i_hidden_size = ir_in_shape[-1]
        out_dtype = ir_in_type.element_type
        out_shape = gi_shape

        out_types = [
            ir.RankedTensorType.get(out_shape, out_dtype),
        ]
        operands = [dz, x]
        operand_shapes = [ir_in_shape, gi_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
                                                               in_dtype, in_dtype)

        out = custom_caller(DGeluPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(dz, x):
        """
        te_dgelu implementation: forwards to the inner (unbatched) primitive.
        """
        assert DGeluPrimitive.inner_primitive is not None
        dx = DGeluPrimitive.inner_primitive.bind(dz, x)
        return dx

    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        te_dgelu batcher for vmap: elementwise op, so the batch dim passes through.
        """
        _check_valid_batch_dims(batch_dims)
        assert DGeluPrimitive.outer_primitive is not None
        dz, x = batched_args
        _, x_bdim = batch_dims

        out_bdims = x_bdim
        return DGeluPrimitive.outer_primitive.bind(dz, x), out_bdims

    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        te_dgelu infer_sharding_from_operands: dx follows the sharding of x
        (operand 1).
        """
        del result_infos    # Unused.
        gelu_out_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*gelu_out_spec))
        return dx_sharding

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        te_dgelu partition: pointwise, so shard dx exactly like x.
        """
        del result_infos
        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = dx_sharding
        impl = DGeluPrimitive.impl
        return mesh, impl, out_shardings, arg_shardings


register_primitive(DGeluPrimitive)
def dgelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray:
    """
    Backward-of-GeLU wrapper.

    Computes the gradient of gelu evaluated at `gelu_inputs`, chained
    with the incoming gradient `inputs`.
    """
    primitive = DGeluPrimitive.outer_primitive
    return primitive.bind(inputs, gelu_inputs)
class GatedGeluPrimitive(BasePrimitive): class GatedGeluPrimitive(BasePrimitive):
""" """
Gated Gelu Froward Primitive Gated Gelu Froward Primitive
...@@ -3190,7 +3402,7 @@ class CastFP8Primitive(BasePrimitive): ...@@ -3190,7 +3402,7 @@ class CastFP8Primitive(BasePrimitive):
x, amax, scale, scale_inv = batched_args x, amax, scale, scale_inv = batched_args
x_bdim, amax_bdim, *_ = batch_dims x_bdim, amax_bdim, *_ = batch_dims
out_bdims = x_bdim, x_bdim, amax_bdim out_bdims = x_bdim, amax_bdim
return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv,
out_dtype=out_dtype), out_bdims out_dtype=out_dtype), out_bdims
...@@ -3386,13 +3598,14 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -3386,13 +3598,14 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
assert gamma_aval.size == beta_aval.size assert gamma_aval.size == beta_aval.size
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # in type jax_dtype_to_te_dtype(x_aval.dtype), # in type
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight type jax_dtype_to_te_dtype(gamma_aval.dtype), # weight type
jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(out_dtype),
True, zero_centered_gamma, epsilon True,
) zero_centered_gamma,
epsilon)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
...@@ -3477,14 +3690,14 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -3477,14 +3690,14 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
hidden_size, hidden_size,
wkspace_aval.size, wkspace_aval.size,
barrier_aval.size, barrier_aval.size,
0, # no dgamma_part in FWD pass 0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass 0, # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype), jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype TEDType.kByte, # dummy dbeta_part te_dtype
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
sm_margin, sm_margin,
...@@ -3636,13 +3849,14 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -3636,13 +3849,14 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
rsigama_dtype = jnp.float32 rsigama_dtype = jnp.float32
wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch_size x_aval.size // hidden_size, # batch_size
hidden_size, hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(out_dtype), # out te_dtype jax_dtype_to_te_dtype(out_dtype), # out te_dtype
False, False, epsilon False,
) False,
epsilon)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype) rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype)
...@@ -3716,14 +3930,14 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -3716,14 +3930,14 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
hidden_size, hidden_size,
wkspace_aval.size, wkspace_aval.size,
barrier_aval.size, barrier_aval.size,
0, # no dgamma_part in FWD pass 0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass 0, # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
jax_dtype_to_te_dtype(barrier_aval.dtype), jax_dtype_to_te_dtype(barrier_aval.dtype),
TEDType.kByte, # dummy dgamma_part te_dtype TEDType.kByte, # dummy dgamma_part te_dtype
TEDType.kByte, # dummy dbeta_part te_dtype TEDType.kByte, # dummy dbeta_part te_dtype
False, # RMSNorm doesn't support zero_centered_gamma False, # RMSNorm doesn't support zero_centered_gamma
epsilon, epsilon,
sm_margin, sm_margin,
...@@ -3832,6 +4046,379 @@ def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale ...@@ -3832,6 +4046,379 @@ def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale
epsilon=epsilon) epsilon=epsilon)
class GeluFp8Primitive(BasePrimitive):
    """
    Fused GeLU forward + FP8 cast primitive, backed by the TransformerEngine
    "te_gelu_fp8" custom call. Produces (FP8(gelu(x)), updated_amax).
    """
    name = "te_gelu_fp8"
    multiple_results = True
    impl_static_args = (4,)    # out_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
        """
        te_gelu_fp8 abstract: FP8 output with x's shape, plus updated amax.
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
        """
        te_gelu_fp8 lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        # The kernel treats the input as a 2D (batch, hidden) matrix.
        hidden_size = ir_x_shape[-1]
        batch_size = reduce(operator.mul, ir_x_shape[:-1])
        out_shape = ir_x_shape
        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size),
                                                               jax_dtype_to_te_dtype(x_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        # amax (operand 1) is updated in place and aliased to output 1.
        out = custom_caller(GeluFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 1})

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype):
        """
        te_gelu_fp8 implementation: forwards to the inner (unbatched) primitive.
        """
        assert GeluFp8Primitive.inner_primitive is not None
        out, updated_amax = GeluFp8Primitive.inner_primitive.bind(x,
                                                                  amax,
                                                                  scale,
                                                                  scale_inv,
                                                                  out_dtype=out_dtype)
        return out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        """
        te_gelu_fp8 batch rules for vmap: elementwise, batch dims pass through.
        """
        _check_valid_batch_dims(batch_dims)
        assert GeluFp8Primitive.outer_primitive is not None
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, amax_bdim
        return GeluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv,
                                                     out_dtype=out_dtype), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
        # Output sharding follows x; the amax output keeps amax's sharding.
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, mesh, arg_infos, result_infos):
        # Pointwise over x; the local amax is all-reduced (max) across the
        # mesh so every shard observes the global amax.
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (out_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_x, local_amax = GeluFp8Primitive.impl(x,
                                                        amax,
                                                        scale,
                                                        scale_inv,
                                                        out_dtype=out_dtype)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(GeluFp8Primitive)
def gelu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
             out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Fused GeLU + FP8-cast wrapper.

    Returns (FP8(gelu(x)), updated_amax).
    """
    return GeluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)
class DGeluDBiasCastTransposePrimitive(BasePrimitive):
    """
    DGelu DBias Cast Transpose Primitive

    Fuses the GeLU backward pass (dgelu), the bias-gradient reduction (dbias)
    and the FP8 cast + transpose into a single Transformer Engine kernel.
    """
    name = "te_dgelu_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (5, 6, 7)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 static_axis_boundary, transpose_axis_boundary):
        """
        te_dgelu_dbias_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert ir_hidden_size == gi_hidden_size
        t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        # dbias keeps only the leading static (batch) axes plus the hidden axis.
        dbias_shape = (*x_aval.shape[:static_axis_boundary + 1], gi_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        # Query the TE backend for the scratch workspace the fused kernel needs.
        wkspace_info, = transformer_engine_jax.get_dgelu_dbias_ct_workspace_sizes(
            x_aval.size // gi_hidden_size,
            gi_hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
        wkspace_aval = x_aval.update(shape=wkspace_info[0],
                                     dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
        return out, t_out, dbias, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dgelu_dbias_cast_transpose_p outer abstract

        Same as `abstract` but without the workspace, which is an
        implementation detail of the inner primitive.
        """
        out, t_out, dbias, updated_amax_aval, _ = \
            DGeluDBiasCastTransposePrimitive.abstract(*args, **kwargs)
        return out, t_out, dbias, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_dgelu_dbias_cast_transpose_p lowering rules
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        assert ir_dz_shape == x_shape
        # The kernel works on a 2D view: all leading axes fold into the batch.
        batch_size = reduce(operator.mul, ir_dz_shape[:-1])
        ir_hidden_size = ir_dz_shape[-1]
        contracted_x_shape = (batch_size, ir_hidden_size)
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary,
                                                 transpose_axis_boundary)
        dbias_shape = (*x_shape[:static_axis_boundary + 1], ir_hidden_size)
        wkspace_aval = ctx.avals_out[-1]
        out_types = [
            ir.RankedTensorType.get(x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
        ]
        operands = [dz, x, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_common_wk_descriptor(
            contracted_x_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype))
        # amax (operand 2) aliases the updated-amax output (output 3): in-place update.
        out = custom_caller(DGeluDBiasCastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={2: 3})
        return out

    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary,
             transpose_axis_boundary):
        """
        to describe implementation

        Binds the inner primitive and drops the trailing workspace output.
        """
        assert DGeluDBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DGeluDBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary)
        return out, t_out, dbias, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
                transpose_axis_boundary):
        """
        to describe batch rules for vmap
        """
        del static_axis_boundary
        _check_valid_batch_dims(batch_dims)
        assert DGeluDBiasCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims
        # Minus batch dim.
        transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1    # Plus batch dim
        out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
        return DGeluDBiasCastTransposePrimitive.outer_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=x_bdim,
            transpose_axis_boundary=transpose_axis_boundary), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh,
                                     arg_infos, result_infos):
        """Infer output shardings from the x (arg 1) and amax (arg 2) operands."""
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, transposed_out_sharding, dbias_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                  result_infos):
        """Partition rule: run the local impl per shard, then all-reduce dbias
        (sum over DP/FSDP) and amax (max over all axes except PP)."""
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_sharding,
                         amax_sharding)

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = DGeluDBiasCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary)
            # dbias is a partial sum on each shard; sum it across DP/FSDP.
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            # amax must be the global max across all shards (except PP).
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_out, local_t_out, global_dbias, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(DGeluDBiasCastTransposePrimitive)
def dgelu_dbias_cast_transpose(
        dz: jnp.ndarray,
        x: jnp.ndarray,
        amax: jnp.ndarray,
        scale: jnp.ndarray,
        scale_inv: jnp.ndarray,
        out_dtype: jnp.dtype,
        static_axis_boundary: int,
        transpose_axis_boundary: int = -1
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose dgelu and dbias fusion wrapper.
    Return FP8(dgelu(dz)), FP8(transpose(dgelu(dz))), dbias, updated_amax

    Parameters:
        dz: incoming gradient.
        x: GeLU input saved from the forward pass.
        amax: running absolute-max used for FP8 scaling.
        scale: FP8 quantization scale.
        scale_inv: inverse of the FP8 quantization scale.
        out_dtype: FP8 dtype of the cast outputs.
        static_axis_boundary: number of leading axes kept static (e.g. vmap
            batch axes); negative means none.
        transpose_axis_boundary: axis at which the transpose splits the shape.
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1    # means no static axes
    return DGeluDBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary)
class GatedGeluFp8Primitive(BasePrimitive): class GatedGeluFp8Primitive(BasePrimitive):
""" """
Gated Gelu FP8 Primitive Gated Gelu FP8 Primitive
......
...@@ -25,6 +25,10 @@ pybind11::dict Registrations() { ...@@ -25,6 +25,10 @@ pybind11::dict Registrations() {
pybind11::dict dict; pybind11::dict dict;
dict["te_transpose"] = EncapsulateFunction(Transpose); dict["te_transpose"] = EncapsulateFunction(Transpose);
dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose); dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose);
dict["te_gelu"] = EncapsulateFunction(Gelu);
dict["te_gelu_fp8"] = EncapsulateFunction(GeluFP8);
dict["te_dgelu"] = EncapsulateFunction(DGelu);
dict["te_dgelu_dbias_cast_transpose"] = EncapsulateFunction(DGeluDBiasCastTranspose);
dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu); dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu);
dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8); dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu); dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
...@@ -55,6 +59,7 @@ pybind11::dict Registrations() { ...@@ -55,6 +59,7 @@ pybind11::dict Registrations() {
PYBIND11_MODULE(transformer_engine_jax, m) { PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("registrations", &Registrations); m.def("registrations", &Registrations);
m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor); m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor);
m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor);
m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor); m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor); m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
...@@ -62,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -62,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability); m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes);
m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes); m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes); m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes); m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes);
......
...@@ -61,18 +61,29 @@ pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, ...@@ -61,18 +61,29 @@ pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape,
return PackOpaque(desc); return PackOpaque(desc);
} }
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
const std::vector<size_t> &wkshape, DType in_dtype,
DType out_dtype, DType wk_dtype) {
CustomCallCommonWkDescriptor desc;
desc.shape.from_vector(shape);
desc.wkshape.from_vector(wkshape);
desc.in_dtype = in_dtype;
desc.out_dtype = out_dtype;
desc.wk_dtype = wk_dtype;
return PackOpaque(desc);
}
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size, pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, size_t barrier_size, size_t wkspace_size, size_t barrier_size,
size_t *dgamma_part_sizes, size_t *dbeta_part_sizes, size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
DType x_dtype, DType w_dtype, DType x_dtype, DType w_dtype, DType wkspace_dtype,
DType wkspace_dtype, DType barrier_dtype, DType barrier_dtype, DType dgamma_part_dtype,
DType dgamma_part_dtype, DType dbeta_part_dtype, DType dbeta_part_dtype, bool zero_centered_gamma,
bool zero_centered_gamma, float eps, int sm_margin) { float eps, int sm_margin) {
return PackOpaque(CustomCallNormDescriptor{batch_size, hidden_size, wkspace_size, barrier_size, return PackOpaque(CustomCallNormDescriptor{
dgamma_part_sizes, dbeta_part_sizes, batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes, dbeta_part_sizes,
x_dtype, w_dtype, wkspace_dtype, barrier_dtype, x_dtype, w_dtype, wkspace_dtype, barrier_dtype, dgamma_part_dtype, dbeta_part_dtype,
dgamma_part_dtype, dbeta_part_dtype, zero_centered_gamma, eps, sm_margin});
zero_centered_gamma, eps, sm_margin});
} }
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size, pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
...@@ -83,11 +94,10 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t paddin ...@@ -83,11 +94,10 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t paddin
} }
pybind11::bytes PackCustomCallFusedAttnDescriptor( pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size, float scaling_factor,
float scaling_factor, float dropout_probability, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType wkspace_dtype, bool is_training) {
DType dtype, DType wkspace_dtype, bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{ return PackOpaque(CustomCallFusedAttnDescriptor{
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim, wkspace_size, batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim, wkspace_size,
scaling_factor, dropout_probability, bias_type, mask_type, dtype, wkspace_dtype, scaling_factor, dropout_probability, bias_type, mask_type, dtype, wkspace_dtype,
...@@ -149,6 +159,138 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size ...@@ -149,6 +159,138 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
input_cast_trans_tensor.data(), stream); input_cast_trans_tensor.data(), stream);
} }
void GeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
scale, scale_inverse);
nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
}
void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
GeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output);
}
void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
float *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
GeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
output);
}
void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *gelu_input = buffers[1];
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto gelu_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);
nvte_dgelu(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(), stream);
}
pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto gelu_input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
auto dbias_shape = std::vector<size_t>{hidden_size};
auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto gelu_input_tensor = TensorWrapper(nullptr, gelu_input_shape, in_dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);
TensorWrapper dummy_workspace;
nvte_cast_transpose_dbias_dgelu(input_tensor.data(), gelu_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), dummy_workspace.data(), nullptr);
auto work_shape = MakeShapeVector(dummy_workspace.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *gelu_input = buffers[1];
float *amax = reinterpret_cast<float *>(buffers[2]);
float *scale = reinterpret_cast<float *>(buffers[3]);
float *scale_inv = reinterpret_cast<float *>(buffers[4]);
auto *output = buffers[5];
auto *output_trans = buffers[6];
auto *dbias = buffers[7];
float *amax_out = reinterpret_cast<float *>(buffers[8]);
void *workspace_ptr = buffers[9];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto gelu_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto dbias_shape = std::vector<size_t>{n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor =
TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
nvte_cast_transpose_dbias_dgelu(input_tensor.data(), gelu_input_tensor.data(),
output_tensor.data(), output_trans_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
}
void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output) { cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
auto input_shape = std::vector<size_t>{m, n * 2}; auto input_shape = std::vector<size_t>{m, n * 2};
...@@ -251,10 +393,10 @@ void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *op ...@@ -251,10 +393,10 @@ void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *op
output_trans_tensor.data(), stream); output_trans_tensor.data(), stream);
} }
pybind11::tuple GetLayerNormForwardWorkspaceSizes( pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma, float eps bool is_layer_norm, bool zero_centered_gamma,
) { float eps) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size}; auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size}; auto intermediates_shape = std::vector<size_t>{batch_size};
...@@ -289,13 +431,12 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes( ...@@ -289,13 +431,12 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(
std::make_pair(barrier_shape, dummy_barrier_tensor.dtype())); std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()));
} }
void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspace_size,
size_t workspace_size, size_t barrier_size, size_t barrier_size, bool zero_centered_gamma, float eps, void *input,
bool zero_centered_gamma, float eps, void *input, DType in_dtype, DType in_dtype, void *weight, DType w_dtype, void *bias, void *output,
void *weight, DType w_dtype, void *bias, void *output, DType out_dtype, DType out_dtype, void *workspace, DType work_dtype, void *barrier,
void *workspace, DType work_dtype, void *barrier, DType barrier_dtype, DType barrier_dtype, void *mu, void *rsigma, float *amax, float *scale,
void *mu, void *rsigma, float *amax, float *scale, float *scale_inv, float *scale_inv, cudaStream_t stream) {
cudaStream_t stream) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size}; auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size}; auto intermediates_shape = std::vector<size_t>{batch_size};
...@@ -333,10 +474,10 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, ...@@ -333,10 +474,10 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size,
} }
} }
pybind11::tuple GetLayerNormBackwardWorkspaceSizes( pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm, DType in_dtype, DType w_dtype,
bool zero_centered_gamma, float eps bool is_layer_norm, bool zero_centered_gamma,
) { float eps) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size}; auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size}; auto intermediates_shape = std::vector<size_t>{batch_size};
...@@ -373,8 +514,8 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes( ...@@ -373,8 +514,8 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dummy_dgamma_part_tensor.data(), nullptr, dummy_dgamma_part_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data()); dummy_barrier_tensor.data());
dbeta_part_shape = std::vector<size_t>{0, 0}; dbeta_part_shape = std::vector<size_t>{0, 0};
} }
...@@ -388,15 +529,13 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes( ...@@ -388,15 +529,13 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(
std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype())); std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype()));
} }
void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size,
size_t wkspace_size, size_t barrier_size, size_t barrier_size, size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
size_t *dgamma_part_sizes, size_t *dbeta_part_sizes, bool zero_centered_gamma, float eps, void *input, DType in_dtype,
bool zero_centered_gamma, float eps, void *weight, DType w_dtype, void *ograd, void *workspace,
void *input, DType in_dtype, void *weight, DType w_dtype, void *ograd, DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu,
void *workspace, DType wkspace_dtype, void *barrier, DType barrier_dtype, void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *dgamma_part,
void *mu, void *rsigma, void *xgrad, void *wgrad, void *dbeta, DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype,
void *dgamma_part, DType dgamma_dtype,
void* dbeta_part, DType dbeta_dtype,
cudaStream_t stream) { cudaStream_t stream) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size}; auto weight_shape = std::vector<size_t>{hidden_size};
...@@ -479,10 +618,10 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque ...@@ -479,10 +618,10 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto out_dtype = DType::kFloat8E4M3; auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
mu, rsigma, amax, scale, scale_inv, stream); stream);
} }
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -513,10 +652,10 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -513,10 +652,10 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin; auto sm_margin = desc.sm_margin;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
mu, rsigma, amax, scale, scale_inv, stream); stream);
} }
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -551,11 +690,11 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -551,11 +690,11 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto *dgamma_part = buffers[10]; auto *dgamma_part = buffers[10];
auto *dbeta_part = buffers[11]; auto *dbeta_part = buffers[11];
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
dgamma_part_sizes, dbeta_part_sizes, zero_centered_gamma, eps, dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta, rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dgamma_part, dgamma_part_dtype, dbeta_part, dbeta_part_dtype, stream); dbeta_part_dtype, stream);
} }
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -588,10 +727,10 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -588,10 +727,10 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto sm_margin = desc.sm_margin; auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3; auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
mu, rsigma, amax, scale, scale_inv, stream); stream);
} }
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -622,10 +761,10 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz ...@@ -622,10 +761,10 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
auto sm_margin = desc.sm_margin; auto sm_margin = desc.sm_margin;
auto out_dtype = in_dtype; auto out_dtype = in_dtype;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
mu, rsigma, amax, scale, scale_inv, stream); stream);
} }
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -659,11 +798,11 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si ...@@ -659,11 +798,11 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
dgamma_part_sizes, dbeta_part_sizes, zero_centered_gamma, eps, dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta, rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dgamma_part, dgamma_part_dtype, dbeta_part, dbeta_part_dtype, stream); dbeta_part_dtype, stream);
} }
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -741,8 +880,8 @@ void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char ...@@ -741,8 +880,8 @@ void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char
auto *output = buffers[2]; auto *output = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto io_shape = std::vector<size_t>{desc.batch_size, desc.head_dim, auto io_shape =
desc.q_seqlen, desc.k_seqlen}; std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto mask_shape = std::vector<size_t>{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen}; auto mask_shape = std::vector<size_t>{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype; auto dtype = desc.dtype;
...@@ -818,11 +957,11 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, ...@@ -818,11 +957,11 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
- common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812 - common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812
- common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359 - common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359
*/ */
void PrepareFusedAttnForwardAuxTensors( void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
NVTETensorPack *tensor_pack, const CustomCallFusedAttnDescriptor *desc, const CustomCallFusedAttnDescriptor *desc,
NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend, NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
void *softmax_buf, void *rng_state_buf = nullptr, void *bias_buf = nullptr void *softmax_buf, void *rng_state_buf = nullptr,
) { void *bias_buf = nullptr) {
auto batch_size = desc->batch_size; auto batch_size = desc->batch_size;
auto num_heads = desc->num_heads; auto num_heads = desc->num_heads;
auto q_max_seqlen = desc->q_max_seqlen; auto q_max_seqlen = desc->q_max_seqlen;
...@@ -833,8 +972,8 @@ void PrepareFusedAttnForwardAuxTensors( ...@@ -833,8 +972,8 @@ void PrepareFusedAttnForwardAuxTensors(
tensor_pack->size = 1; tensor_pack->size = 1;
Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]); Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
softmax_aux->data.dptr = softmax_buf; softmax_aux->data.dptr = softmax_buf;
softmax_aux->data.shape = std::vector<size_t>{ softmax_aux->data.shape =
batch_size, num_heads, q_max_seqlen, kv_max_seqlen}; std::vector<size_t>{batch_size, num_heads, q_max_seqlen, kv_max_seqlen};
softmax_aux->data.dtype = desc->dtype; softmax_aux->data.dtype = desc->dtype;
// arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
...@@ -867,10 +1006,10 @@ void PrepareFusedAttnForwardAuxTensors( ...@@ -867,10 +1006,10 @@ void PrepareFusedAttnForwardAuxTensors(
TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()? TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()?
*/ */
void PrepareFusedAttnBackwardAuxTensors( void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack,
NVTETensorPack* tensor_pack, const CustomCallFusedAttnDescriptor *desc, const CustomCallFusedAttnDescriptor *desc,
NVTE_Fused_Attn_Backend backend, void* softmax_buf, void* rng_state_buf, void* bias_buf NVTE_Fused_Attn_Backend backend, void *softmax_buf,
) { void *rng_state_buf, void *bias_buf) {
// Backward calls put everything into the tensor pack for every backend // Backward calls put everything into the tensor pack for every backend
// so we set dummy bias_type and backend choices here to follow the correct code path // so we set dummy bias_type and backend choices here to follow the correct code path
auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS; auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
...@@ -880,17 +1019,16 @@ void PrepareFusedAttnBackwardAuxTensors( ...@@ -880,17 +1019,16 @@ void PrepareFusedAttnBackwardAuxTensors(
// correct softmax shape for max512 sequence length kernel // correct softmax shape for max512 sequence length kernel
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
Tensor* softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]); Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
softmax_aux->data.shape.at(3) = desc->kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks} softmax_aux->data.shape.at(3) = desc->kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks}
softmax_aux->data.dtype = desc->dtype; softmax_aux->data.dtype = desc->dtype;
} }
} }
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes( pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float scaling_factor, float dropout_probability, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training bool is_training) {
) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim}; auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
...@@ -898,17 +1036,16 @@ pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes( ...@@ -898,17 +1036,16 @@ pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto cu_seqlens_tensor = TensorWrapper( auto cu_seqlens_tensor =
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto o_tensor = TensorWrapper( auto o_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim}, dtype); nullptr, std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim}, dtype);
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64); auto rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
bias_type, mask_type, dropout_probability, num_heads, num_heads, mask_type, dropout_probability, num_heads, num_heads, max_seqlen, max_seqlen, head_dim);
max_seqlen, max_seqlen, head_dim);
NVTETensorPack aux_output_tensors; NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors); nvte_tensor_pack_create(&aux_output_tensors);
...@@ -916,9 +1053,9 @@ pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes( ...@@ -916,9 +1053,9 @@ pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
TensorWrapper query_workspace_tensor; TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(), o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), max_seqlen, is_training, rng_state_tensor.data(), max_seqlen, is_training, scaling_factor,
scaling_factor, dropout_probability, qkv_layout, dropout_probability, qkv_layout, bias_type, mask_type,
bias_type, mask_type, query_workspace_tensor.data(), nullptr); query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
...@@ -957,8 +1094,8 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu ...@@ -957,8 +1094,8 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
// input tensors // input tensors
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
auto cu_seqlens_tensor = TensorWrapper( auto cu_seqlens_tensor =
cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32); TensorWrapper(cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// output tensors // output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16 auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
...@@ -969,9 +1106,8 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu ...@@ -969,9 +1106,8 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
bias_type, mask_type, dropout_probability, num_heads, num_heads, mask_type, dropout_probability, num_heads, num_heads, max_seqlen, max_seqlen, head_dim);
max_seqlen, max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, max_seqlen, max_seqlen, backend, stream); PopulateRngStateAsync(rng_state, seed, max_seqlen, max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later) // auxiliary tensors (to be propagated to the backward pass later)
...@@ -995,10 +1131,9 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu ...@@ -995,10 +1131,9 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
} }
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes( pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float scaling_factor, float dropout_probability, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training bool is_training) {
) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim}; auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
...@@ -1014,8 +1149,8 @@ pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes( ...@@ -1014,8 +1149,8 @@ pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype); auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto cu_seqlens_tensor = TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, auto cu_seqlens_tensor =
DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors; NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors); nvte_tensor_pack_create(&aux_input_tensors);
...@@ -1084,11 +1219,10 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -1084,11 +1219,10 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
nvte_tensor_pack_create(&aux_input_tensors); nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
bias_type, mask_type, dropout_probability, num_heads, num_heads, mask_type, dropout_probability, num_heads, num_heads, max_seqlen, max_seqlen, head_dim);
max_seqlen, max_seqlen, head_dim); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, rng_state, bias);
softmax_aux, rng_state, bias);
// cuDNN workspace // cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size}; auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
...@@ -1107,11 +1241,9 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -1107,11 +1241,9 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
} }
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes( pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim}; auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
...@@ -1128,10 +1260,10 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes( ...@@ -1128,10 +1260,10 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype); auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto q_cu_seqlens_tensor = TensorWrapper( auto q_cu_seqlens_tensor =
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = TensorWrapper( auto kv_cu_seqlens_tensor =
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64); auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
...@@ -1139,12 +1271,12 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes( ...@@ -1139,12 +1271,12 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
nvte_tensor_pack_create(&aux_output_tensors); nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor; TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd_kvpacked( nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(),
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, is_training, scaling_factor, dropout_probability, qkv_layout,
query_workspace_tensor.data(), nullptr); bias_type, mask_type, query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
...@@ -1203,9 +1335,9 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -1203,9 +1335,9 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
bias_type, mask_type, dropout_probability, num_heads, num_gqa_groups, mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
q_max_seqlen, kv_max_seqlen, head_dim); head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later) // auxiliary tensors (to be propagated to the backward pass later)
...@@ -1215,25 +1347,23 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -1215,25 +1347,23 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
softmax_aux); softmax_aux);
// cuDNN workspace // cuDNN workspace
auto workspace_tensor = TensorWrapper( auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
workspace, std::vector<size_t>{descriptor.wkspace_size}, descriptor.wkspace_dtype); descriptor.wkspace_dtype);
nvte_fused_attn_fwd_kvpacked( nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(),
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training, rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, descriptor.is_training, scaling_factor, dropout_probability,
workspace_tensor.data(), stream); qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors); nvte_tensor_pack_destroy(&aux_output_tensors);
} }
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes( pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim}; auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
...@@ -1252,10 +1382,10 @@ pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes( ...@@ -1252,10 +1382,10 @@ pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype); auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype); auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto q_cu_seqlens_tensor = TensorWrapper( auto q_cu_seqlens_tensor =
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = TensorWrapper( auto kv_cu_seqlens_tensor =
nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors; NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors); nvte_tensor_pack_create(&aux_input_tensors);
...@@ -1267,8 +1397,8 @@ pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes( ...@@ -1267,8 +1397,8 @@ pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
s_tensor.data(), // not used for FP16/BF16 s_tensor.data(), // not used for FP16/BF16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
bias_type, mask_type, query_workspace_tensor.data(), nullptr); query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
...@@ -1325,21 +1455,21 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa ...@@ -1325,21 +1455,21 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
auto dq_tensor = TensorWrapper(dq, q_shape, dtype); auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype); auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto q_cu_seqlens_tensor = TensorWrapper( auto q_cu_seqlens_tensor =
q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32); TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = TensorWrapper( auto kv_cu_seqlens_tensor =
kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32); TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// auxiliary tensors (propagated from the forward pass) // auxiliary tensors (propagated from the forward pass)
NVTETensorPack aux_input_tensors; NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors); nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
bias_type, mask_type, dropout_probability, num_heads, num_gqa_groups, mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
q_max_seqlen, kv_max_seqlen, head_dim); head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
softmax_aux, rng_state, bias); rng_state, bias);
// cuDNN workspace // cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size}; auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
...@@ -1352,8 +1482,8 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa ...@@ -1352,8 +1482,8 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
s_tensor.data(), // not used for FP16/BF16 s_tensor.data(), // not used for FP16/BF16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
bias_type, mask_type, workspace_tensor.data(), stream); workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_input_tensors); nvte_tensor_pack_destroy(&aux_input_tensors);
} }
......
...@@ -52,6 +52,18 @@ struct CustomCallCommonDescriptor { ...@@ -52,6 +52,18 @@ struct CustomCallCommonDescriptor {
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype, pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype); DType out_dtype);
struct CustomCallCommonWkDescriptor {
Shape shape;
Shape wkshape;
DType in_dtype;
DType out_dtype;
DType wk_dtype;
};
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
const std::vector<size_t> &wkshape, DType in_dtype,
DType out_dtype, DType wk_dtype);
struct CustomCallNormDescriptor { struct CustomCallNormDescriptor {
size_t batch_size; size_t batch_size;
size_t hidden_size; size_t hidden_size;
...@@ -73,10 +85,10 @@ struct CustomCallNormDescriptor { ...@@ -73,10 +85,10 @@ struct CustomCallNormDescriptor {
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size, pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, size_t barrier_size, size_t wkspace_size, size_t barrier_size,
size_t *dgamma_part_sizes, size_t *dbeta_part_sizes, size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
DType x_dtype, DType w_dtype, DType x_dtype, DType w_dtype, DType wkspace_dtype,
DType wkspace_dtype, DType barrier_dtype, DType barrier_dtype, DType dgamma_part_dtype,
DType dgamma_part_dtype, DType dbeta_part_dtype, DType dbeta_part_dtype, bool zero_centered_gamma,
bool zero_centered_gamma, float eps, int sm_margin); float eps, int sm_margin);
struct SoftmaxDescriptor { struct SoftmaxDescriptor {
size_t batch_size; size_t batch_size;
...@@ -110,11 +122,10 @@ struct CustomCallFusedAttnDescriptor { ...@@ -110,11 +122,10 @@ struct CustomCallFusedAttnDescriptor {
}; };
pybind11::bytes PackCustomCallFusedAttnDescriptor( pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size, float scaling_factor,
float scaling_factor, float dropout_probability, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType wkspace_dtype, bool is_training);
DType dtype, DType wkspace_dtype, bool is_training);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
...@@ -127,6 +138,18 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o ...@@ -127,6 +138,18 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
...@@ -136,20 +159,20 @@ void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t ...@@ -136,20 +159,20 @@ void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetLayerNormForwardWorkspaceSizes( pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma, float eps bool is_layer_norm, bool zero_centered_gamma,
); float eps);
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetLayerNormBackwardWorkspaceSizes( pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm, DType in_dtype, DType w_dtype,
bool zero_centered_gamma, float eps bool is_layer_norm, bool zero_centered_gamma,
); float eps);
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
...@@ -182,39 +205,33 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, ...@@ -182,39 +205,33 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
std::size_t opaque_len); std::size_t opaque_len);
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes( pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float scaling_factor, float dropout_probability, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training bool is_training);
);
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes( pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float scaling_factor, float dropout_probability, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training bool is_training);
);
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes( pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes( pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
......
...@@ -23,8 +23,10 @@ from ..fp8 import FP8Helper, FP8MetaPackage ...@@ -23,8 +23,10 @@ from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot from ..layernorm import layernorm, layernorm_fp8_dot
from ..mlp import layernorm_geglu_fp8_mlp, geglu from ..mlp import layernorm_geglu_fp8_mlp, geglu
from ..mlp import layernorm_gelu_fp8_mlp, gelu
from ..softmax import is_softmax_kernel_available from ..softmax import is_softmax_kernel_available
from ..softmax import softmax, SoftmaxType from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
...@@ -502,6 +504,14 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -502,6 +504,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
If set False, return None as the second tensor in outputs. If set False, return None as the second tensor in outputs.
axis: Union[Iterable[int], int], default = -1 axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on. An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of layernorm, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint.
dot_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of dot, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint.
Optimization parameters Optimization parameters
----------------------- -----------------------
...@@ -534,6 +544,8 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -534,6 +544,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None depth_scaling: float = None
sharding_type = None sharding_type = None
...@@ -571,6 +583,8 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -571,6 +583,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
) and not self.return_layernorm_output and self.enable_layernorm ) and not self.return_layernorm_output and self.enable_layernorm
if self.enable_layernorm: if self.enable_layernorm:
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
assert self.axis == -1 # Only support axis = =-1 at this moment assert self.axis == -1 # Only support axis = =-1 at this moment
features = inputs.shape[-1] features = inputs.shape[-1]
...@@ -626,8 +640,11 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -626,8 +640,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
fp8_meta_package, fp8_meta_package,
self.layernorm_type, self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon) epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_input_axes)
else: else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
z = type_safe_dot_general(y, z = type_safe_dot_general(y,
kernel, kernel,
fp8_meta_pkg=fp8_meta_package, fp8_meta_pkg=fp8_meta_package,
...@@ -730,6 +747,18 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -730,6 +747,18 @@ class LayerNormMLP(TransformerEngineBase):
Dimensions that will share the same dropout mask for hidden Dimensions that will share the same dropout mask for hidden
axis: Union[Iterable[int], int], default = -1 axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on. An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of layernorm, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint.
dot_1_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of 1st dot, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint.
dot_2_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of 2nd dot, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint.
Optimization parameters Optimization parameters
----------------------- -----------------------
...@@ -765,6 +794,9 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -765,6 +794,9 @@ class LayerNormMLP(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_1_input_axes: Tuple[str, ...] = None
dot_2_input_axes: Tuple[str, ...] = None
major_sharding_type = None major_sharding_type = None
def __post_init__(self): def __post_init__(self):
...@@ -812,13 +844,28 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -812,13 +844,28 @@ class LayerNormMLP(TransformerEngineBase):
normalize_acts.append(act.lower()) normalize_acts.append(act.lower())
return tuple(normalize_acts) in geglu_act_pool return tuple(normalize_acts) in geglu_act_pool
use_fused_ln_mlp = fuse_layernorm \ def is_gelu(acts):
geglu_act_pool = [('gelu',)]
normalize_acts = []
for act in acts:
if not isinstance(act, str):
return False
normalize_acts.append(act.lower())
return tuple(normalize_acts) in geglu_act_pool
use_fused_ln_geglu_mlp = fuse_layernorm \
and (not self.use_bias) and is_geglu(self.activations) \ and (not self.use_bias) and is_geglu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3) and (self.intermediate_dropout_rate < 1e-3)
use_fused_ln_gelu_mlp = fuse_layernorm \
and self.use_bias and is_gelu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3)
# LayerNorm # LayerNorm
if self.enable_layernorm: if self.enable_layernorm:
assert self.axis == -1 # Only support axis == -1 at this moment assert self.axis == -1 # Only support axis == -1 at this moment
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
features = inputs.shape[-1] features = inputs.shape[-1]
...@@ -883,7 +930,10 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -883,7 +930,10 @@ class LayerNormMLP(TransformerEngineBase):
kernel_2 = jnp.reshape(kernel_2, kernel_2_shape) kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
if use_fused_ln_mlp: ffn1_ckpt_name = 'ffn1'
ffn2_ckpt_name = 'ffn2'
if use_fused_ln_geglu_mlp:
assert self.axis == -1 # Only support axis = =-1 at this moment assert self.axis == -1 # Only support axis = =-1 at this moment
out = layernorm_geglu_fp8_mlp(y, out = layernorm_geglu_fp8_mlp(y,
...@@ -892,8 +942,41 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -892,8 +942,41 @@ class LayerNormMLP(TransformerEngineBase):
fp8_meta_package, fp8_meta_package,
self.layernorm_type, self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon) epsilon=self.epsilon,
else: # not use_fused_ln_mlp layernorm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name)
elif use_fused_ln_gelu_mlp:
assert self.axis == -1 # Only support axis = =-1 at this moment
bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init,
intermediate_dim,
jnp.float32,
axes=self.bias_axes_1)
bias_1 = bias_1.astype(self.dtype)
bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init, (hidden_size,),
jnp.float32,
axes=self.bias_axes_2)
bias_2 = bias_2.astype(self.dtype)
out = layernorm_gelu_fp8_mlp(y,
scale,
ln_bias, [kernel_1, kernel_2], [bias_1, bias_2],
fp8_meta_package,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name)
else: # not use_fused_ln_geglu_mlp
# DenseGeneral 1 # DenseGeneral 1
gemm1_fp8_meta_package = None if fp8_meta_package is None \ gemm1_fp8_meta_package = None if fp8_meta_package is None \
...@@ -906,8 +989,11 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -906,8 +989,11 @@ class LayerNormMLP(TransformerEngineBase):
gemm1_fp8_meta_package, gemm1_fp8_meta_package,
self.layernorm_type, self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon) epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_1_input_axes)
else: else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
x = type_safe_dot_general(y, x = type_safe_dot_general(y,
kernel_1, kernel_1,
fp8_meta_pkg=gemm1_fp8_meta_package, fp8_meta_pkg=gemm1_fp8_meta_package,
...@@ -924,11 +1010,14 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -924,11 +1010,14 @@ class LayerNormMLP(TransformerEngineBase):
bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
x += jnp.reshape(bias, bias_shape) x += jnp.reshape(bias, bias_shape)
x = checkpoint_name(x, 'ffn1') x = checkpoint_name(x, ffn1_ckpt_name)
activations = [] activations = []
if is_geglu(self.activations): if is_geglu(self.activations):
z = geglu(x) z = geglu(x)
elif is_gelu(self.activations):
z = gelu(x)
z = jnp.reshape(z, (*z.shape[:-2], -1))
else: else:
x = jnp.split(x, num_activations, axis=-2) x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(self.activations): for idx, act_fn in enumerate(self.activations):
...@@ -942,6 +1031,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -942,6 +1031,8 @@ class LayerNormMLP(TransformerEngineBase):
rng_collection=self.intermediate_dropout_rng_name)( rng_collection=self.intermediate_dropout_rng_name)(
z, deterministic=deterministic) z, deterministic=deterministic)
z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
# DenseGeneral 2 # DenseGeneral 2
gemm2_fp8_meta_package = None if fp8_meta_package is None \ gemm2_fp8_meta_package = None if fp8_meta_package is None \
else fp8_meta_package.get_package_by_gemm_idx(1) else fp8_meta_package.get_package_by_gemm_idx(1)
...@@ -960,6 +1051,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -960,6 +1051,6 @@ class LayerNormMLP(TransformerEngineBase):
bias = bias.astype(self.dtype) bias = bias.astype(self.dtype)
out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,)) out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))
out = checkpoint_name(out, 'ffn2') out = checkpoint_name(out, ffn2_ckpt_name)
return out, ln_output # Output, layner_norm_output return out, ln_output # Output, layner_norm_output
...@@ -27,8 +27,12 @@ from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout ...@@ -27,8 +27,12 @@ from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from ..fused_attn import is_fused_attn_kernel_available from ..fused_attn import is_fused_attn_kernel_available
from ..fused_attn import self_fused_attn, cross_fused_attn from ..fused_attn import self_fused_attn, cross_fused_attn
from ..softmax import SoftmaxType from ..softmax import SoftmaxType
from ..sharding import global_mesh_resource, num_of_devices from ..sharding import num_of_devices
from ..sharding import with_sharding_constraint from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
...@@ -39,17 +43,6 @@ PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Preci ...@@ -39,17 +43,6 @@ PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Preci
Initializer = Callable[[PRNGKey, Shape, DType], Array] Initializer = Callable[[PRNGKey, Shape, DType], Array]
LogicalRules = Sequence[Tuple[str, Union[str, None]]] LogicalRules = Sequence[Tuple[str, Union[str, None]]]
BATCH_AXES = 'nvte_batch'
SEQLEN_AXES = 'nvte_seqlen'
HEAD_AXES = 'nvte_head'
HIDDEN_AXES = 'nvte_hidden'
HIDDEN_TP_AXES = 'nvte_hidden_tp'
JOINED_AXES = 'nvte_joined'
W_NO_SHARD_AXES = 'nvte_w_no_shard'
W_FSDP_AXES = 'nvte_w_fsdp'
W_TP_AXES = 'nvte_w_tp'
W_JOINED_AXES = 'nvte_w_joined'
def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]: def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
# Generate broadcast dims for drop_path. # Generate broadcast dims for drop_path.
...@@ -101,36 +94,8 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: ...@@ -101,36 +94,8 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
else: else:
rules_map[key] = [val] rules_map[key] = [val]
gsr = global_mesh_resource()
batch_dim_rule = []
if gsr.dp_resource is not None:
batch_dim_rule.append(gsr.dp_resource)
if gsr.fsdp_resource is not None and gsr.dp_resource != gsr.fsdp_resource:
batch_dim_rule.append(gsr.fsdp_resource)
if len(batch_dim_rule) <= 0:
batch_dim_rule = None
elif len(batch_dim_rule) == 1:
batch_dim_rule = batch_dim_rule[0]
else:
batch_dim_rule = tuple(batch_dim_rule)
te_logical_axis_rules = (
(BATCH_AXES, batch_dim_rule),
(SEQLEN_AXES, None),
(HEAD_AXES, gsr.tp_resource),
(HIDDEN_AXES, None),
(HIDDEN_TP_AXES, gsr.tp_resource),
(JOINED_AXES, None),
(W_NO_SHARD_AXES, None),
(W_FSDP_AXES, gsr.fsdp_resource),
(W_TP_AXES, gsr.tp_resource),
(W_JOINED_AXES, None),
)
extended_rules = [*rules] extended_rules = [*rules]
for item in te_logical_axis_rules: for item in get_sharding_map_logic_axis_to_mesh_axis().items():
key = item[0] key = item[0]
val = item[1] val = item[1]
if key in rules_map: if key in rules_map:
...@@ -143,18 +108,6 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: ...@@ -143,18 +108,6 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
return tuple(extended_rules) return tuple(extended_rules)
def _with_sharding_constraint(x: Array, logical_axis_names: Shape):
assert len(x.shape) == len(logical_axis_names)
rules = extend_logical_axis_rules(tuple())
rules_dict = {}
for key, value in rules:
rules_dict[key] = value
mesh_axis_names = [rules_dict[name] for name in logical_axis_names]
pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
return with_sharding_constraint(x, pspec)
def _merge_mask(func, *masks: Optional[Array]): def _merge_mask(func, *masks: Optional[Array]):
masks = [m for m in masks if m is not None] masks = [m for m in masks if m is not None]
if not masks: if not masks:
...@@ -175,7 +128,10 @@ def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32): ...@@ -175,7 +128,10 @@ def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
def combine_biases(*masks: Optional[Array]): def combine_biases(*masks: Optional[Array]):
"""Combine attention biases.""" """Combine attention biases."""
func = lambda a, b: a + b
def func(a, b):
return a + b
return _merge_mask(func, *masks) return _merge_mask(func, *masks)
...@@ -234,8 +190,8 @@ def core_attention(query: Array, ...@@ -234,8 +190,8 @@ def core_attention(query: Array,
attn_weights_without_groups_shape = (b, h * g, q, k) attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape) attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
attn_weights = _with_sharding_constraint(attn_weights, attn_weights = with_sharding_constraint_by_logical_axes(
(BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)) attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))
# When a bias is present, the computation is performed as Softmax(attn_weights * scale + bias). # When a bias is present, the computation is performed as Softmax(attn_weights * scale + bias).
# In this case, the scale can not fused into the Softmax module. # In this case, the scale can not fused into the Softmax module.
...@@ -270,6 +226,39 @@ def core_attention(query: Array, ...@@ -270,6 +226,39 @@ def core_attention(query: Array,
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
def rotary_pos_emb(x: Array, windows: Tuple[int, int], transpose_batch_sequence: bool):
"""
Rotary Positional Embedding
x should be in shape of
[Batch, Seqlen, ..., Hidden] if transpose_batch_sequence is False, or
[Seqlen, Batch, ..., Hidden] if transpose_batch_sequence is True.
"""
embed_dim = x.shape[-1]
half_embed_dim = embed_dim // 2
min_window = windows[0]
max_window = windows[1]
fraction = 2 * jnp.arange(0, half_embed_dim) / embed_dim
time_scales = min_window * (max_window / min_window)**fraction
time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))
batch_dim = 1 if transpose_batch_sequence else 0
seq_dim = 1 - batch_dim
positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))
sinusoidal_positions = positions / time_scales
sin = jnp.sin(sinusoidal_positions)
cos = jnp.cos(sinusoidal_positions)
x1, x2 = jnp.split(x, 2, axis=-1)
part_1 = (x1 * cos - x2 * sin).astype(x.dtype)
part_2 = (x2 * cos + x1 * sin).astype(x.dtype)
return jnp.concatenate([part_1, part_2], axis=-1)
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
...@@ -331,6 +320,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -331,6 +320,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type: {'causal', 'padding'}, default = 'causal' attn_mask_type: {'causal', 'padding'}, default = 'causal'
Type of attention mask passed into softmax operation. Type of attention mask passed into softmax operation.
Introduced in v0.10.0. Introduced in v0.10.0.
enable_rotary_pos_emb: bool, default = False
Whether to enable rotary position embedding to projected query and key.
rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
Indicate the min and max time-scales of rotary position embedding,
only used when :attr:`enable_rotary_pos_emb=True`
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism to operations except dot.
Optimization parameters Optimization parameters
----------------------- -----------------------
...@@ -368,9 +364,12 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -368,9 +364,12 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
apply_residual_connection_post_layernorm: bool = False apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False output_layernorm: bool = False
attn_mask_type: str = 'causal' attn_mask_type: str = 'causal'
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
dtype: DType = jnp.float32 dtype: DType = jnp.float32
fuse_qkv: bool = True fuse_qkv: bool = True
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
float32_logits: bool = False # computes logits in float32 for stability. float32_logits: bool = False # computes logits in float32 for stability.
...@@ -501,6 +500,22 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -501,6 +500,22 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
f"Fused attention is not enabled. Because " \ f"Fused attention is not enabled. Because " \
f"{reason}fall back to unfused attention.") f"{reason}fall back to unfused attention.")
def generate_batch_seqlen_logical_axes(is_sharded_seq):
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
axes = [None, None]
axes[batch_dim] = BATCH_AXES
axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
return tuple(axes)
inputs_logical_axes_maybe_sp = (*generate_batch_seqlen_logical_axes(
self.enable_sequence_parallel), HIDDEN_AXES)
inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)
inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)
residual = inputs_q residual = inputs_q
if self.fuse_qkv: if self.fuse_qkv:
if is_qkvpack: if is_qkvpack:
...@@ -520,6 +535,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -520,6 +535,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias, use_bias=self.use_bias,
bias_init=self.bias_init, bias_init=self.bias_init,
bias_axes=(W_JOINED_AXES, W_TP_AXES), bias_axes=(W_JOINED_AXES, W_TP_AXES),
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
name='qkv', name='qkv',
dtype=self.dtype)(inputs_q) dtype=self.dtype)(inputs_q)
qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj') qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj')
...@@ -544,6 +561,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -544,6 +561,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_TP_AXES,), bias_axes=(W_TP_AXES,),
dtype=self.dtype, dtype=self.dtype,
kernel_init=query_init, kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
name='query')(inputs_q) name='query')(inputs_q)
if is_self_attn: if is_self_attn:
...@@ -591,6 +610,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -591,6 +610,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_TP_AXES,), bias_axes=(W_TP_AXES,),
dtype=self.dtype, dtype=self.dtype,
kernel_init=query_init, kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
name='query')(inputs_q) name='query')(inputs_q)
if is_self_attn: if is_self_attn:
...@@ -604,6 +625,23 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -604,6 +625,23 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
assert ln_out is not None assert ln_out is not None
residual = ln_out residual = ln_out
if self.enable_rotary_pos_emb:
if self.fuse_qkv and use_fused_attn:
if is_qkvpack:
query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
else:
key, value = jnp.split(kv_proj, [1], axis=-2)
query = rotary_pos_emb(query, self.rotary_pos_emb_windows,
self.transpose_batch_sequence)
key = rotary_pos_emb(key, self.rotary_pos_emb_windows, self.transpose_batch_sequence)
if use_fused_attn:
if is_qkvpack:
qkv_proj = jnp.concatenate([query, key, value], axis=-2)
else:
kv_proj = jnp.concatenate([key, value], axis=-2)
if not use_fused_attn: if not use_fused_attn:
query = checkpoint_name(query, 'query_proj') query = checkpoint_name(query, 'query_proj')
key = checkpoint_name(key, 'key_proj') key = checkpoint_name(key, 'key_proj')
...@@ -615,9 +653,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -615,9 +653,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
(SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \ (SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \
if self.transpose_batch_sequence \ if self.transpose_batch_sequence \
else (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES) else (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
query = _with_sharding_constraint(query, qkv_sharding_constraint) query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
key = _with_sharding_constraint(key, qkv_sharding_constraint) key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
value = _with_sharding_constraint(value, qkv_sharding_constraint) value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
if decode: if decode:
is_initialized = self.has_variable('cache', 'cached_key') is_initialized = self.has_variable('cache', 'cached_key')
...@@ -679,7 +717,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -679,7 +717,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim)) qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES, qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES) HIDDEN_AXES)
qkv_proj = _with_sharding_constraint(qkv_proj, qkv_sharding_constraint) qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj,
qkv_sharding_constraint)
x = self_fused_attn(qkv_proj, x = self_fused_attn(qkv_proj,
bias, bias,
mask, mask,
...@@ -696,8 +736,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -696,8 +736,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES) q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES, kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES) HIDDEN_AXES)
query = _with_sharding_constraint(query, q_sharding_constraint) query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
kv_proj = _with_sharding_constraint(kv_proj, kv_sharding_constraint) kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
x = cross_fused_attn(query, x = cross_fused_attn(query,
kv_proj, kv_proj,
...@@ -748,7 +788,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -748,7 +788,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
(SEQLEN_AXES, BATCH_AXES, HIDDEN_TP_AXES) \ (SEQLEN_AXES, BATCH_AXES, HIDDEN_TP_AXES) \
if self.transpose_batch_sequence \ if self.transpose_batch_sequence \
else (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES) else (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
x = _with_sharding_constraint(x, attn_context_sharding_constraint) x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)
out = DenseGeneral(features=inputs_q.shape[-1], out = DenseGeneral(features=inputs_q.shape[-1],
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
...@@ -967,6 +1007,13 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -967,6 +1007,13 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
num_attention_heads=self.num_attention_heads, dtype=self.dtype, num_attention_heads=self.num_attention_heads, dtype=self.dtype,
embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'), embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
name='relpos_bias') name='relpos_bias')
enable_rotary_pos_emb: bool, default = False
Whether to enable rotary position embedding to projected query and key in MHA.
rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
Indicate the min and max time-scales of rotary position embedding,
only used when :attr:`enable_rotary_pos_emb=True`
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism to operations except dot.
Optimization parameters Optimization parameters
----------------------- -----------------------
...@@ -1016,10 +1063,13 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1016,10 +1063,13 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
self_attn_mask_type: str = 'causal' self_attn_mask_type: str = 'causal'
enable_relative_embedding: bool = True enable_relative_embedding: bool = True
relative_embedding: nn.Module = None relative_embedding: nn.Module = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
dtype: DType = jnp.float32 dtype: DType = jnp.float32
drop_path: float = 0.0 drop_path: float = 0.0
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
...@@ -1089,6 +1139,16 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1089,6 +1139,16 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
sequence_dim = 0 if self.transpose_batch_sequence else 1 sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim batch_dim = 1 - sequence_dim
def generate_batch_seqlen_logical_axes(is_shared_seq=None):
axes = [None, None]
is_shared_seq = self.enable_sequence_parallel if is_shared_seq is None \
else is_shared_seq
axes[batch_dim] = BATCH_AXES
axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
return tuple(axes)
attn_bias = None attn_bias = None
if self.enable_relative_embedding: if self.enable_relative_embedding:
if self.relative_embedding is None: if self.relative_embedding is None:
...@@ -1120,7 +1180,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1120,7 +1180,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
else: else:
mha_name = 'self_attention' mha_name = 'self_attention'
inputs = _with_sharding_constraint(inputs, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)) inputs = with_sharding_constraint_by_logical_axes(
inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
# [batch, length, emb_dim] -> [batch, length, emb_dim] # [batch, length, emb_dim] -> [batch, length, emb_dim]
x, residual = MultiHeadAttention( x, residual = MultiHeadAttention(
...@@ -1129,6 +1190,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1129,6 +1190,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
head_dim=head_dim, head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups, num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
dropout_rate=self.attention_dropout, dropout_rate=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name, dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_attention_logits, float32_logits=self.float32_attention_logits,
...@@ -1140,6 +1202,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1140,6 +1202,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm, apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm, output_layernorm=self.output_layernorm,
attn_mask_type=self.self_attn_mask_type, attn_mask_type=self.self_attn_mask_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
fuse_qkv=self.fuse_qkv_params, fuse_qkv=self.fuse_qkv_params,
kernel_init=self.mha_kernel_init, kernel_init=self.mha_kernel_init,
use_bias=self.use_bias, use_bias=self.use_bias,
...@@ -1161,6 +1225,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1161,6 +1225,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
broadcast_dims=self.hidden_dropout_dims, broadcast_dims=self.hidden_dropout_dims,
rng_collection=self.dropout_rng_name)(x, deterministic=deterministic) rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)
x = with_sharding_constraint_by_logical_axes(
x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
residual = with_sharding_constraint_by_logical_axes(
residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
x = hidden_dropout(x, deterministic) x = hidden_dropout(x, deterministic)
if self.drop_path > 0.0: if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim) drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
...@@ -1174,12 +1243,16 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1174,12 +1243,16 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
assert encoded is not None, \ assert encoded is not None, \
"encoded is required when layer_type == TransformerLayerType.DECODER." "encoded is required when layer_type == TransformerLayerType.DECODER."
x = with_sharding_constraint_by_logical_axes(
x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
y, residual = MultiHeadAttention( y, residual = MultiHeadAttention(
num_heads=self.num_attention_heads, num_heads=self.num_attention_heads,
dtype=self.dtype, dtype=self.dtype,
head_dim=head_dim, head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups, num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
dropout_rate=self.attention_dropout, dropout_rate=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name, dropout_rng_name=self.dropout_rng_name,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
...@@ -1189,6 +1262,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1189,6 +1262,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
apply_residual_connection_post_layernorm, apply_residual_connection_post_layernorm,
output_layernorm=False, # Must do LayerNorm before MHA. output_layernorm=False, # Must do LayerNorm before MHA.
attn_mask_type='padding', attn_mask_type='padding',
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
float32_logits=self.float32_attention_logits, float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits, scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init, scaled_query_init=self.scaled_query_init,
...@@ -1200,10 +1275,17 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1200,10 +1275,17 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
encoded, encoded,
encoder_decoder_mask, encoder_decoder_mask,
deterministic=deterministic) deterministic=deterministic)
y = with_sharding_constraint_by_logical_axes(
y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
residual = with_sharding_constraint_by_logical_axes(
residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
y = hidden_dropout(y, deterministic) y = hidden_dropout(y, deterministic)
mlp_input = y + residual mlp_input = y + residual
mlp_input = _with_sharding_constraint(mlp_input, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)) mlp_input = with_sharding_constraint_by_logical_axes(
mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
# MlpBlock # MlpBlock
residual = mlp_input residual = mlp_input
...@@ -1228,6 +1310,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1228,6 +1310,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
bias_init=self.bias_init, bias_init=self.bias_init,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES), bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
bias_axes_2=(W_NO_SHARD_AXES,), bias_axes_2=(W_NO_SHARD_AXES,),
layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
name='mlp', name='mlp',
)(mlp_input, deterministic=deterministic) )(mlp_input, deterministic=deterministic)
...@@ -1235,6 +1320,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1235,6 +1320,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
assert ln_out is not None assert ln_out is not None
residual = ln_out residual = ln_out
z = with_sharding_constraint_by_logical_axes(
z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
residual = with_sharding_constraint_by_logical_axes(
residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
z = hidden_dropout(z, deterministic) z = hidden_dropout(z, deterministic)
if self.drop_path > 0.0: if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim) drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
...@@ -1243,6 +1333,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1243,6 +1333,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
z = z + residual z = z + residual
if self.output_layernorm: if self.output_layernorm:
z = with_sharding_constraint_by_logical_axes(
z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
z = LayerNorm(layernorm_type=self.layernorm_type, z = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
"""JAX layernorm modules""" """JAX layernorm modules"""
from functools import partial from functools import partial
from typing import Tuple
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -12,6 +14,7 @@ from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd ...@@ -12,6 +14,7 @@ from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl, get_precision_of_fp8_dot from .dot import fp8_dot_impl, get_precision_of_fp8_dot
from .fp8 import FP8Helper, FP8MetaPackage from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
def canonicalize_layernorm_type(x): def canonicalize_layernorm_type(x):
...@@ -96,14 +99,20 @@ def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz): ...@@ -96,14 +99,20 @@ def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule) _layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)
def layernorm_fp8_dot(x: jnp.ndarray, def layernorm_fp8_dot(
kernel: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, kernel: jnp.ndarray,
beta: jnp.ndarray, gamma: jnp.ndarray,
fp8_meta_pkg: FP8MetaPackage, beta: jnp.ndarray,
layernorm_type: str, fp8_meta_pkg: FP8MetaPackage,
zero_centered_gamma: bool = False, layernorm_type: str,
epsilon: float = 1e-6) -> jnp.ndarray: zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
layernorm_input_axes: Tuple[
str, ...] = None, # The logic axes of sharding constraint to the layernorm input.
dot_input_axes: Tuple[str,
...] = None # The logic axes of sharding constraint to the dot input.
) -> jnp.ndarray:
""" """
Layernorm + FP8 GEMM Layernorm + FP8 GEMM
""" """
...@@ -114,18 +123,21 @@ def layernorm_fp8_dot(x: jnp.ndarray, ...@@ -114,18 +123,21 @@ def layernorm_fp8_dot(x: jnp.ndarray,
fwd_dtype = FP8Helper.FWD_DTYPE fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE bwd_dtype = FP8Helper.BWD_DTYPE
output = _layernorm_fp8_dot(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv, output = _layernorm_fp8_dot(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
layernorm_type, fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon) layernorm_type, fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon,
layernorm_input_axes, dot_input_axes)
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12)) @partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14))
def _layernorm_fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, def _layernorm_fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
fp8_max: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray, layernorm_type: str, fwd_dtype: jnp.dtype, scale_inv: jnp.ndarray, layernorm_type: str, fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype, zero_centered_gamma: bool, epsilon: float): bwd_dtype: jnp.dtype, zero_centered_gamma: bool, epsilon: float,
layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...]):
output, _ = _layernorm_fp8_dot_fwd_rule(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv, output, _ = _layernorm_fp8_dot_fwd_rule(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
layernorm_type, fwd_dtype, bwd_dtype, layernorm_type, fwd_dtype, bwd_dtype,
zero_centered_gamma, epsilon) zero_centered_gamma, epsilon, layernorm_input_axes,
dot_input_axes)
return output return output
...@@ -142,7 +154,9 @@ def _layernorm_fp8_dot_fwd_rule( ...@@ -142,7 +154,9 @@ def _layernorm_fp8_dot_fwd_rule(
fwd_dtype, fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument bwd_dtype, # pylint: disable=unused-argument
zero_centered_gamma, zero_centered_gamma,
epsilon): epsilon,
layernorm_input_axes,
dot_input_axes):
x_contracting_dims = (len(x.shape) - 1,) x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,) k_contracting_dims = (0,)
...@@ -156,6 +170,8 @@ def _layernorm_fp8_dot_fwd_rule( ...@@ -156,6 +170,8 @@ def _layernorm_fp8_dot_fwd_rule(
x_scale = scale[gemm_x_idx] x_scale = scale[gemm_x_idx]
x_scale_inv = scale_inv[gemm_x_idx] x_scale_inv = scale_inv[gemm_x_idx]
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
if layernorm_type == 'layernorm': if layernorm_type == 'layernorm':
ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8( ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
x, x,
...@@ -191,6 +207,8 @@ def _layernorm_fp8_dot_fwd_rule( ...@@ -191,6 +207,8 @@ def _layernorm_fp8_dot_fwd_rule(
casted_kernel, updated_kernel_amax = \ casted_kernel, updated_kernel_amax = \
cast_fp8(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype) cast_fp8(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype)
ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_input_axes)
# (batch..., hidden_in) x (hidden_in, hidden_out...) # (batch..., hidden_in) x (hidden_in, hidden_out...)
output = fp8_dot_impl(ln_out, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype, output = fp8_dot_impl(ln_out, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype,
(x_contracting_dims, k_contracting_dims), (x_contracting_dims, k_contracting_dims),
...@@ -209,6 +227,8 @@ def _layernorm_fp8_dot_bwd_rule( ...@@ -209,6 +227,8 @@ def _layernorm_fp8_dot_bwd_rule(
bwd_dtype, bwd_dtype,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument
ctx, ctx,
grad): grad):
ln_out_, casted_kernel, fp8_max, amax, scale, scale_inv, \ ln_out_, casted_kernel, fp8_max, amax, scale, scale_inv, \
...@@ -243,6 +263,7 @@ def _layernorm_fp8_dot_bwd_rule( ...@@ -243,6 +263,7 @@ def _layernorm_fp8_dot_bwd_rule(
(g_for_dgrad_constracting_dim, k_constracting_dim), (g_for_dgrad_constracting_dim, k_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
if layernorm_type == 'layernorm': if layernorm_type == 'layernorm':
dx, dgamma, dbeta = layernorm_bwd(dgrad, dx, dgamma, dbeta = layernorm_bwd(dgrad,
x, x,
......
...@@ -3,13 +3,16 @@ ...@@ -3,13 +3,16 @@
# See LICENSE for license information. # See LICENSE for license information.
"""JAX MLP modules""" """JAX MLP modules"""
from typing import List from typing import List, Tuple
from functools import partial from functools import partial
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from .cpp_extensions import cast_fp8, transpose, cast_transpose from .cpp_extensions import cast_fp8, transpose, cast_transpose
from .cpp_extensions import gelu as te_gelu
from .cpp_extensions import gelu_fp8, dgelu, dgelu_dbias_cast_transpose
from .cpp_extensions import gated_gelu, gated_gelu_fp8 from .cpp_extensions import gated_gelu, gated_gelu_fp8
from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
...@@ -17,6 +20,40 @@ from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd ...@@ -17,6 +20,40 @@ from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize
from .layernorm import canonicalize_layernorm_type from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
def gelu(x: jnp.ndarray):
    """
    Apply the GeLU activation through the custom-VJP implementation.

    Args:
        x: Input tensor.

    Returns:
        A tensor of the same shape as ``x`` with GeLU applied elementwise.
    """
    return _gelu(x)
@jax.custom_vjp
def _gelu(x: jnp.ndarray):
    """
    Primal GeLU evaluation for the custom VJP.

    Runs the forward rule and discards the saved residuals; the
    forward/backward rules are registered via ``defvjp`` below.
    """
    # `partial(jax.custom_vjp)` with no bound arguments is just the decorator
    # itself, so apply it directly. Also rename the copy-pasted `geglu_output`
    # local: this is GeLU, not GeGLU.
    gelu_output, _ = _gelu_fwd_rule(x)
    return gelu_output
def _gelu_fwd_rule(x):
    """
    Forward rule: compute GeLU and save the input as the VJP residual.

    Renames the copy-pasted `geglu_output` local — this function computes
    plain GeLU, not GeGLU.
    """
    gelu_output = te_gelu(x)
    return gelu_output, (x,)
def _gelu_bwd_rule(ctx, g):
    """
    Backward rule: propagate the incoming cotangent ``g`` through GeLU.

    The residual tuple holds the primal input; the cotangent must share its
    dtype. The gradient is reshaped back to the input's shape before return.
    """
    (x,) = ctx
    assert x.dtype == g.dtype
    grad_wrt_x = jnp.reshape(dgelu(g, x), x.shape)
    return (grad_wrt_x,)
# Register the custom forward/backward rules for GeLU autodiff.
_gelu.defvjp(_gelu_fwd_rule, _gelu_bwd_rule)
def geglu(x: jnp.ndarray): def geglu(x: jnp.ndarray):
...@@ -47,9 +84,9 @@ def _geglu_bwd_rule(ctx, g): ...@@ -47,9 +84,9 @@ def _geglu_bwd_rule(ctx, g):
x, = ctx x, = ctx
assert x.dtype == g.dtype assert x.dtype == g.dtype
dgelu = dgated_gelu(g, x) dx = dgated_gelu(g, x)
dgelu = jnp.reshape(dgelu, x.shape) dx = jnp.reshape(dx, x.shape)
return (dgelu,) return (dx,)
_geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule) _geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule)
...@@ -62,7 +99,12 @@ def layernorm_geglu_fp8_mlp(x: jnp.ndarray, ...@@ -62,7 +99,12 @@ def layernorm_geglu_fp8_mlp(x: jnp.ndarray,
fp8_gemm_pkg: FP8MetaPackage, fp8_gemm_pkg: FP8MetaPackage,
layernorm_type: str, layernorm_type: str,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
epsilon: float = 1e-6) -> jnp.ndarray: epsilon: float = 1e-6,
layernorm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = 'ffn1',
ffn2_ckpt_name: str = 'ffn2') -> jnp.ndarray:
""" """
Layernorm + GEMM1 + GeGLU + GEMM2 Layernorm + GEMM1 + GeGLU + GEMM2
""" """
...@@ -88,19 +130,26 @@ def layernorm_geglu_fp8_mlp(x: jnp.ndarray, ...@@ -88,19 +130,26 @@ def layernorm_geglu_fp8_mlp(x: jnp.ndarray,
output = _layernorm_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale, output = _layernorm_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale,
scale_inv, fwd_dtype, bwd_dtype, layernorm_type, scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
zero_centered_gamma, epsilon) zero_centered_gamma, epsilon, layernorm_input_axes,
dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name,
ffn2_ckpt_name)
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13)) @partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _layernorm_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, def _layernorm_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray, kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray,
amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str,
zero_centered_gamma: bool, epsilon: float): zero_centered_gamma: bool, epsilon: float,
layernorm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
ffn1_ckpt_name: str, ffn2_ckpt_name: str):
output, _ = _layernorm_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, output, _ = _layernorm_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax,
scale, scale_inv, fwd_dtype, bwd_dtype, scale, scale_inv, fwd_dtype, bwd_dtype,
layernorm_type, zero_centered_gamma, epsilon) layernorm_type, zero_centered_gamma, epsilon,
layernorm_input_axes, dot_1_input_axes,
dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name)
return output return output
...@@ -118,7 +167,12 @@ def _layernorm_geglu_fp8_mlp_fwd_rule( ...@@ -118,7 +167,12 @@ def _layernorm_geglu_fp8_mlp_fwd_rule(
bwd_dtype, # pylint: disable=unused-argument bwd_dtype, # pylint: disable=unused-argument
layernorm_type, layernorm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon): epsilon,
layernorm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name,
ffn2_ckpt_name):
# x should be in shape of (batch..., hidden) # x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (Hidden_in, 2, Hidden_out) # Kernel_1 should be in shape of (Hidden_in, 2, Hidden_out)
...@@ -141,6 +195,8 @@ def _layernorm_geglu_fp8_mlp_fwd_rule( ...@@ -141,6 +195,8 @@ def _layernorm_geglu_fp8_mlp_fwd_rule(
x_scale = scale[gemm1_x_idx] x_scale = scale[gemm1_x_idx]
x_scale_inv = scale_inv[gemm1_x_idx] x_scale_inv = scale_inv[gemm1_x_idx]
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
if layernorm_type == 'layernorm': if layernorm_type == 'layernorm':
ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8( ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
x, x,
...@@ -175,10 +231,13 @@ def _layernorm_geglu_fp8_mlp_fwd_rule( ...@@ -175,10 +231,13 @@ def _layernorm_geglu_fp8_mlp_fwd_rule(
casted_kernel_1, updated_kernel_1_amax = \ casted_kernel_1, updated_kernel_1_amax = \
cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype) cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)
ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)
# (batch..., hidden_in) x (hidden_in, 2, hidden_out) # (batch..., hidden_in) x (hidden_in, 2, hidden_out)
dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype, dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
(x_contracting_dims, (0,)), (x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1) gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
...@@ -191,6 +250,8 @@ def _layernorm_geglu_fp8_mlp_fwd_rule( ...@@ -191,6 +250,8 @@ def _layernorm_geglu_fp8_mlp_fwd_rule(
geglu_out_scale, geglu_out_scale_inv, geglu_out_scale, geglu_out_scale_inv,
fwd_dtype) fwd_dtype)
casted_geglu_out = with_sharding_constraint_by_logical_axes(casted_geglu_out, dot_2_input_axes)
kernel_2_scale = scale[gemm2_kernel_idx] kernel_2_scale = scale[gemm2_kernel_idx]
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx] kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
# Note (Ming Huang): Use native cast to allow XLA handle tranpose for avoiding # Note (Ming Huang): Use native cast to allow XLA handle tranpose for avoiding
...@@ -201,6 +262,7 @@ def _layernorm_geglu_fp8_mlp_fwd_rule( ...@@ -201,6 +262,7 @@ def _layernorm_geglu_fp8_mlp_fwd_rule(
dot_2_output = fp8_dot_impl(casted_geglu_out, casted_kernel_2, geglu_out_scale_inv, dot_2_output = fp8_dot_impl(casted_geglu_out, casted_kernel_2, geglu_out_scale_inv,
kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)), kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kernel_1, ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kernel_1,
casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax,
...@@ -215,6 +277,11 @@ def _layernorm_geglu_fp8_mlp_bwd_rule( ...@@ -215,6 +277,11 @@ def _layernorm_geglu_fp8_mlp_bwd_rule(
layernorm_type, layernorm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
layernorm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name, # pylint: disable=unused-argument
ffn2_ckpt_name, # pylint: disable=unused-argument
ctx, ctx,
grad): grad):
x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, \ x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, \
...@@ -228,6 +295,9 @@ def _layernorm_geglu_fp8_mlp_bwd_rule( ...@@ -228,6 +295,9 @@ def _layernorm_geglu_fp8_mlp_bwd_rule(
grad_scale = scale[gemm2_grad_idx] grad_scale = scale[gemm2_grad_idx]
grad_scale_inv = scale_inv[gemm2_grad_idx] grad_scale_inv = scale_inv[gemm2_grad_idx]
# Since the sharding of outputs should be the same as dot_1's input
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
casted_grad, casted_grad_t, updated_grad_amax = \ casted_grad, casted_grad_t, updated_grad_amax = \
cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype, cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
static_axis_boundary=-1, transpose_axis_boundary=-1) static_axis_boundary=-1, transpose_axis_boundary=-1)
...@@ -248,6 +318,8 @@ def _layernorm_geglu_fp8_mlp_bwd_rule( ...@@ -248,6 +318,8 @@ def _layernorm_geglu_fp8_mlp_bwd_rule(
grad.dtype, (x_contracting_dims, (1,)), grad.dtype, (x_contracting_dims, (1,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0) gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)
dgeglu_amax = amax[gemm1_grad_idx, 0:1] dgeglu_amax = amax[gemm1_grad_idx, 0:1]
...@@ -280,6 +352,8 @@ def _layernorm_geglu_fp8_mlp_bwd_rule( ...@@ -280,6 +352,8 @@ def _layernorm_geglu_fp8_mlp_bwd_rule(
grad.dtype, (x_contracting_dims_plus_act_dim, (1, 2)), grad.dtype, (x_contracting_dims_plus_act_dim, (1, 2)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)
if layernorm_type == 'layernorm': if layernorm_type == 'layernorm':
dx, dgamma, dbeta = layernorm_bwd(dgrad_1, dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
x, x,
...@@ -309,3 +383,315 @@ def _layernorm_geglu_fp8_mlp_bwd_rule( ...@@ -309,3 +383,315 @@ def _layernorm_geglu_fp8_mlp_bwd_rule(
_layernorm_geglu_fp8_mlp.defvjp(_layernorm_geglu_fp8_mlp_fwd_rule, _layernorm_geglu_fp8_mlp.defvjp(_layernorm_geglu_fp8_mlp_fwd_rule,
_layernorm_geglu_fp8_mlp_bwd_rule) _layernorm_geglu_fp8_mlp_bwd_rule)
def layernorm_gelu_fp8_mlp(x: jnp.ndarray,
                           gamma: jnp.ndarray,
                           beta: jnp.ndarray,
                           kernels: List[jnp.ndarray],
                           biases: List[jnp.ndarray],
                           fp8_gemm_pkg: FP8MetaPackage,
                           layernorm_type: str,
                           zero_centered_gamma: bool = False,
                           epsilon: float = 1e-6,
                           layernorm_input_axes: Tuple[str, ...] = None,
                           dot_1_input_axes: Tuple[str, ...] = None,
                           dot_2_input_axes: Tuple[str, ...] = None,
                           ffn1_ckpt_name: str = 'ffn1',
                           ffn2_ckpt_name: str = 'ffn2') -> jnp.ndarray:
    """
    Layernorm + GEMM1 + bias + GeLU + GEMM2 + bias

    Validates the kernel/bias lists and the packed FP8 metadata, canonicalizes
    the layernorm type, then dispatches to the custom-VJP implementation.
    """
    assert len(kernels) == 2
    assert fp8_gemm_pkg.num_of_gemm == len(kernels)

    # Unpack the two GEMMs' parameters and the FP8 metadata package.
    kernel_1, kernel_2 = kernels[0], kernels[1]
    bias_1, bias_2 = biases[0], biases[1]
    fp8_max, amax = fp8_gemm_pkg.fp8_max, fp8_gemm_pkg.amax
    scale, scale_inv = fp8_gemm_pkg.scale, fp8_gemm_pkg.scale_inv
    fwd_dtype, bwd_dtype = FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'rmsnorm':
        # RMSNorm has no beta/shift and no zero-centered-gamma variant.
        assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"

    return _layernorm_gelu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max,
                                   amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
                                   zero_centered_gamma, epsilon, layernorm_input_axes,
                                   dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name,
                                   ffn2_ckpt_name)
@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20))
def _layernorm_gelu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                            kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray,
                            bias_2: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
                            scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype,
                            bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool,
                            epsilon: float, layernorm_input_axes: Tuple[str, ...],
                            dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
                            ffn1_ckpt_name: str, ffn2_ckpt_name: str):
    """
    Differentiable core of ``layernorm_gelu_fp8_mlp`` (a ``jax.custom_vjp``).

    The primal evaluation simply runs the forward rule and drops the residual
    context that the backward rule would consume.
    """
    primal_out, _ = _layernorm_gelu_fp8_mlp_fwd_rule(
        x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max, amax, scale, scale_inv,
        fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon, layernorm_input_axes,
        dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name)
    return primal_out
def _layernorm_gelu_fp8_mlp_fwd_rule(
        x,
        gamma,
        beta,
        kernel_1,
        kernel_2,
        bias_1,
        bias_2,
        fp8_max,
        amax,
        scale,
        scale_inv,
        fwd_dtype,
        bwd_dtype,  # pylint: disable=unused-argument
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,
        ffn2_ckpt_name):
    """
    Forward rule of the custom VJP: LayerNorm + GEMM1 + bias + GeLU + GEMM2 + bias.

    Normalizes ``x`` (LayerNorm or RMSNorm) directly into FP8, runs the first
    FP8 GEMM against ``kernel_1`` and adds ``bias_1``, applies GeLU with an FP8
    output cast, then runs the second FP8 GEMM against ``kernel_2`` and adds
    ``bias_2``. Both GEMM outputs are tagged with ``checkpoint_name`` so
    rematerialization policies can refer to them. Sharding constraints with the
    given logical axes are applied to the layernorm input and to each dot input.

    Returns:
        ``(dot_2_output, ctx)`` where ``ctx`` carries the residuals needed by
        the backward rule.
    """
    # x should be in shape of (batch..., hidden)
    # Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out)
    # Kernel_2 should be in shape of (Hidden_in, Hidden_out)
    assert len(kernel_1.shape) == 3
    assert kernel_1.shape[-2] == 1
    assert len(kernel_2.shape) == 2

    # Contract over the trailing hidden dim; every leading dim is a batch dim
    # (used as the contracting dims of the transposed wgrad GEMMs).
    x_contracting_dims = (len(x.shape) - 1,)
    xt_batch_dims = tuple(range(1, x.ndim))

    assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
    assert kernel_1.shape[-1] == kernel_2.shape[0]

    # Squeeze act axis
    # (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out)
    kernel_1 = jnp.squeeze(kernel_1, axis=-2)

    # Roll the amax history before reading the current amax entries below.
    amax = FP8Helper.update_amax_history(amax)

    gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    x_amax = amax[gemm1_x_idx, 0:1]
    x_scale = scale[gemm1_x_idx]
    x_scale_inv = scale_inv[gemm1_x_idx]

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    if layernorm_type == 'layernorm':
        # Fused normalization + FP8 cast; also returns the updated input amax.
        ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
            x,
            gamma,
            beta,
            x_amax,
            x_scale,
            x_scale_inv,
            out_dtype=fwd_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
                                                         gamma,
                                                         x_amax,
                                                         x_scale,
                                                         x_scale_inv,
                                                         out_dtype=fwd_dtype,
                                                         epsilon=epsilon)
        mu = None  # RMSNorm has no mean statistic.

    assert x.shape == ln_out.shape

    kernel_1_amax = amax[gemm1_kernel_idx, 0:1]
    kernel_1_scale = scale[gemm1_kernel_idx]
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]

    # Note (Ming Huang): Use cast only to allow XLA to handle the transpose,
    # avoiding an unnecessary copy that would break FP8 GEMM pattern matching.
    casted_kernel_1, updated_kernel_1_amax = \
        cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)

    ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)

    # (batch..., hidden_in) x (hidden_in, hidden_out)
    dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
    # Broadcast the bias across the batch dims by left-padding its shape with 1s.
    bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape
    dot_1_output += jnp.reshape(bias_1, bias_1_shape)
    dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)

    gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)

    gelu_out_amax = amax[gemm2_x_idx, 0:1]
    gelu_out_scale = scale[gemm2_x_idx]
    gelu_out_scale_inv = scale_inv[gemm2_x_idx]

    # GeLU with FP8 output cast; shape is preserved.
    casted_gelu_out, updated_gelu_amax = gelu_fp8(dot_1_output, gelu_out_amax, gelu_out_scale,
                                                  gelu_out_scale_inv, fwd_dtype)

    casted_gelu_out = with_sharding_constraint_by_logical_axes(casted_gelu_out, dot_2_input_axes)

    kernel_2_scale = scale[gemm2_kernel_idx]
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    # Note (Ming Huang): Use native cast to allow XLA to handle the transpose,
    # avoiding an unnecessary copy that would break FP8 GEMM pattern matching.
    casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)

    # (batch..., hidden_mid) x (hidden_mid, hidden_out)
    dot_2_output = fp8_dot_impl(casted_gelu_out, casted_kernel_2, gelu_out_scale_inv,
                                kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)),
                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
    bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape
    dot_2_output += jnp.reshape(bias_2, bias_2_shape)
    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

    # Residuals for the backward rule (activations, FP8 meta, updated amaxes,
    # contraction dims, and bias shapes for dbias reshaping).
    ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, casted_kernel_1,
           casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_gelu_amax,
           updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims,
           bias_1.shape, bias_2.shape)

    return dot_2_output, ctx
def _layernorm_gelu_fp8_mlp_bwd_rule(
        fwd_dtype,    # pylint: disable=unused-argument
        bwd_dtype,
        layernorm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_1_input_axes,
        dot_2_input_axes,
        ffn1_ckpt_name,    # pylint: disable=unused-argument
        ffn2_ckpt_name,    # pylint: disable=unused-argument
        ctx,
        grad):
    """
    Backward (VJP) rule for the fused LayerNorm + GEMM1 + GeLU + GEMM2 FP8 MLP.

    Computes gradients w.r.t. the input `x`, the layernorm parameters
    (`gamma`, and `beta` for 'layernorm'), both kernels, and both biases,
    and returns the refreshed FP8 metadata (fp8_max, amax, scale, scale_inv)
    so the amax history is updated as part of the gradient flow.
    """
    # Unpack all residuals saved by the forward rule.
    x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, \
    casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
    updated_gelu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
    x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape= ctx

    # FP8 meta slots for the second GEMM (x = gelu output, kernel_2, grad).
    gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)

    grad_amax = amax[gemm2_grad_idx, 0:1]
    grad_scale = scale[gemm2_grad_idx]
    grad_scale_inv = scale_inv[gemm2_grad_idx]

    # Since the sharding of outputs should be the same as dot_1's input
    grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)

    # Cast the incoming gradient to FP8 and also produce its transpose
    # (needed for the wgrad GEMM), collecting a new amax along the way.
    casted_grad, casted_grad_t, updated_grad_amax = \
        cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
                       static_axis_boundary=-1, transpose_axis_boundary=-1)

    casted_gelu_out_t = transpose(casted_gelu_out,
                                  static_axis_boundary=-1,
                                  transpose_axis_boundary=-1)

    # Bias-2 gradient: reduce over all leading (batch-like) axes, keeping the
    # last (hidden) axis, then restore the original bias shape.
    dbias_2 = jnp.sum(grad, axis=(i for i in range(grad.ndim - 1)))
    dbias_2 = jnp.reshape(dbias_2, bias_2_shape)

    # (hidden, batch...,) x (hidden, batch...)
    gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
    wgrad_2 = fp8_dot_impl(casted_gelu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv,
                           grad.dtype, (xt_batch_dims, xt_batch_dims),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # (batch..., hidden_out) x (hidden_in, hidden_out)
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
                           grad.dtype, (x_contracting_dims, (1,)),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)

    # FP8 meta slots for the first GEMM (x = ln_out, kernel_1, dgelu).
    gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)

    dgelu_amax = amax[gemm1_grad_idx, 0:1]
    dgelu_scale = scale[gemm1_grad_idx]
    dgelu_scale_inv = scale_inv[gemm1_grad_idx]

    # Fused kernel: backprop through GeLU (using the saved pre-activation
    # dot_1_output), compute the bias-1 gradient, and cast(+transpose) the
    # resulting dgelu to FP8 in one pass.
    casted_dgelu, casted_dgelu_t, dbias_1, updated_dgelu_amax = dgelu_dbias_cast_transpose(
        dgrad_2,
        dot_1_output,
        dgelu_amax,
        dgelu_scale,
        dgelu_scale_inv,
        bwd_dtype,
        static_axis_boundary=-1,
        transpose_axis_boundary=-1)

    dbias_1 = jnp.reshape(dbias_1, bias_1_shape)

    ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)

    # (hidden, batch...) x (hidden, batch...)
    gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
    wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgelu_t, gemm1_x_scale_inv, dgelu_scale_inv, grad.dtype,
                           (xt_batch_dims, xt_batch_dims),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # Expand act axis to match the shape with the given kernel_1
    wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2)

    # (batch..., hidden_out) x (hidden_in, hidden_out)
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
    dgrad_1 = fp8_dot_impl(casted_dgelu, casted_kernel_1, dgelu_scale_inv, kernel_1_scale_inv,
                           grad.dtype, (x_contracting_dims, (1,)),
                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)

    # Backprop through the normalization; only 'layernorm' has a beta/dbeta.
    if layernorm_type == 'layernorm':
        dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
                                          x,
                                          mu,
                                          rsigma,
                                          gamma,
                                          zero_centered_gamma=zero_centered_gamma,
                                          epsilon=epsilon)
    else:
        assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
            "if layernorm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
        dbeta = None

    # Fold the amaxes observed during fwd/bwd back into the amax history.
    # NOTE(review): updated_kernel_2_amax is set without a [0] index, unlike
    # the others — presumably quantize() returns a scalar amax; confirm.
    amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
    amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
    amax = amax.at[gemm1_grad_idx, 0].set(updated_dgelu_amax[0])
    amax = amax.at[gemm2_x_idx, 0].set(updated_gelu_amax[0])
    amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
    amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])

    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)

    return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \
           fp8_max, amax, scale, scale_inv
_layernorm_gelu_fp8_mlp.defvjp(_layernorm_gelu_fp8_mlp_fwd_rule, _layernorm_gelu_fp8_mlp_bwd_rule)
...@@ -77,6 +77,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer): ...@@ -77,6 +77,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
attn_mask_type: str = 'causal' attn_mask_type: str = 'causal'
fuse_qkv: bool = True fuse_qkv: bool = True
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
float32_logits: bool = False float32_logits: bool = False
...@@ -109,6 +110,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer): ...@@ -109,6 +110,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
attn_mask_type=self.attn_mask_type, attn_mask_type=self.attn_mask_type,
fuse_qkv=self.fuse_qkv, fuse_qkv=self.fuse_qkv,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits, scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init, scaled_query_init=self.scaled_query_init,
float32_logits=self.float32_logits) float32_logits=self.float32_logits)
...@@ -156,11 +158,14 @@ class TransformerLayer(TransformerEngineBaseLayer): ...@@ -156,11 +158,14 @@ class TransformerLayer(TransformerEngineBaseLayer):
float32_attention_logits: bool = False float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER layer_type: TransformerLayerType = TransformerLayerType.ENCODER
self_attn_mask_type: str = 'causal' self_attn_mask_type: str = 'causal'
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
enable_relative_embedding: bool = True enable_relative_embedding: bool = True
relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None) relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
drop_path: float = 0.0 drop_path: float = 0.0
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
...@@ -221,11 +226,14 @@ class TransformerLayer(TransformerEngineBaseLayer): ...@@ -221,11 +226,14 @@ class TransformerLayer(TransformerEngineBaseLayer):
float32_attention_logits=self.float32_attention_logits, float32_attention_logits=self.float32_attention_logits,
layer_type=self.layer_type, layer_type=self.layer_type,
self_attn_mask_type=self.self_attn_mask_type, self_attn_mask_type=self.self_attn_mask_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
enable_relative_embedding=self.enable_relative_embedding, enable_relative_embedding=self.enable_relative_embedding,
relative_embedding=relative_embedding_flax_module, relative_embedding=relative_embedding_flax_module,
drop_path=self.drop_path, drop_path=self.drop_path,
fuse_qkv_params=self.fuse_qkv_params, fuse_qkv_params=self.fuse_qkv_params,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits, scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init) scaled_query_init=self.scaled_query_init)
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
""" """
Sharding Meta for xmap with CustomCall Sharding Meta for xmap with CustomCall
""" """
import os
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
...@@ -16,6 +16,19 @@ from jax.sharding import PartitionSpec ...@@ -16,6 +16,19 @@ from jax.sharding import PartitionSpec
_PXLA_THREAD_RESOURCES = pxla.thread_resources _PXLA_THREAD_RESOURCES = pxla.thread_resources
# Axis Names
# Logical axis names used to annotate TE/JAX tensors; they are translated to
# physical mesh axes by get_sharding_map_logic_axis_to_mesh_axis().
BATCH_AXES = 'nvte_batch'    # batch dim, sharded over DP and/or FSDP resources
SEQLEN_AXES = 'nvte_seqlen'    # sequence dim, replicated
SEQLEN_TP_AXES = 'nvte_seqlen_tp'    # sequence dim sharded over TP (sequence parallelism)
HEAD_AXES = 'nvte_head'    # attention-head dim, sharded over TP
HIDDEN_AXES = 'nvte_hidden'    # hidden dim, replicated
HIDDEN_TP_AXES = 'nvte_hidden_tp'    # hidden dim sharded over TP
JOINED_AXES = 'nvte_joined'    # joined/collapsed dim, replicated
W_NO_SHARD_AXES = 'nvte_w_no_shard'    # weight dim, never sharded
W_FSDP_AXES = 'nvte_w_fsdp'    # weight dim sharded over FSDP
W_TP_AXES = 'nvte_w_tp'    # weight dim sharded over TP
W_JOINED_AXES = 'nvte_w_joined'    # joined weight dim, replicated
def _get_mesh_info(resource: str): def _get_mesh_info(resource: str):
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
...@@ -24,6 +37,81 @@ def _get_mesh_info(resource: str): ...@@ -24,6 +37,81 @@ def _get_mesh_info(resource: str):
return mesh.shape[resource], resource return mesh.shape[resource], resource
def get_sharding_map_logic_axis_to_mesh_axis():
    """
    Generate a dict that maps each TE logical axis name to a mesh axis
    (a mesh-resource name, a tuple of names, or None for replicated).

    The batch axis may be sharded over both the DP and the FSDP mesh
    resources. The environment variable NVTE_OUTER_BATCH_FSDP_DIM
    (default "0") selects which one is the outer dimension of the batch
    PartitionSpec entry: "1" puts FSDP outer, otherwise DP is outer.
    """
    gsr = global_mesh_resource()

    # Fix: environ.get's default should be a string; the previous `False`
    # default only worked because int(False) == 0. Behavior is unchanged.
    is_fsdp_outer = bool(int(os.environ.get("NVTE_OUTER_BATCH_FSDP_DIM", "0")))
    batch_resources = [gsr.fsdp_resource, gsr.dp_resource] if is_fsdp_outer \
        else [gsr.dp_resource, gsr.fsdp_resource]

    # Drop unset resources and duplicates (DP and FSDP may alias the same
    # mesh axis), preserving order.
    batch_dim_rule = []
    for resource in batch_resources:
        if resource is not None and resource not in batch_dim_rule:
            batch_dim_rule.append(resource)

    # Collapse to the form PartitionSpec expects:
    # None (replicated), a single axis name, or a tuple of axis names.
    if not batch_dim_rule:
        batch_dim_rule = None
    elif len(batch_dim_rule) == 1:
        batch_dim_rule = batch_dim_rule[0]
    else:
        batch_dim_rule = tuple(batch_dim_rule)

    te_logical_axis_to_mesh_axis = {
        BATCH_AXES: batch_dim_rule,
        SEQLEN_AXES: None,
        SEQLEN_TP_AXES: gsr.tp_resource,
        HEAD_AXES: gsr.tp_resource,
        HIDDEN_AXES: None,
        HIDDEN_TP_AXES: gsr.tp_resource,
        JOINED_AXES: None,
        W_NO_SHARD_AXES: None,
        W_FSDP_AXES: gsr.fsdp_resource,
        W_TP_AXES: gsr.tp_resource,
        W_JOINED_AXES: None,
    }
    return te_logical_axis_to_mesh_axis
def generate_pspec(logical_axis_names):
    """
    Convert a sequence of TE logical axis names into a
    jax.sharding.PartitionSpec using the current logical-to-mesh axis rules.
    """
    axis_rules = get_sharding_map_logic_axis_to_mesh_axis()
    mesh_axes = (axis_rules[axis] for axis in logical_axis_names)
    return jax.sharding.PartitionSpec(*mesh_axes)
def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
    """
    Wrapper around jax.lax.with_sharding_constraint that degrades to a
    no-op when no PartitionSpec is given or when the active mesh is empty.
    """
    if pspec is None:
        return x
    if _PXLA_THREAD_RESOURCES.env.physical_mesh.empty:
        return x
    return jax.lax.with_sharding_constraint(x, pspec)
def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: tuple | list):
    """
    Apply a sharding constraint to `x` described by TE logical axis names
    (one name per dimension of `x`). A None `logical_axis_names` is a no-op.
    """
    if logical_axis_names is None:
        return x

    # One logical name is required per array dimension.
    assert len(x.shape) == len(logical_axis_names)
    return with_sharding_constraint(x, generate_pspec(logical_axis_names))
def get_all_mesh_axes(): def get_all_mesh_axes():
""" """
Get all name of mesh axes Get all name of mesh axes
...@@ -42,17 +130,6 @@ def get_padded_spec(spec, ndim): ...@@ -42,17 +130,6 @@ def get_padded_spec(spec, ndim):
return spec + (None,) * (ndim - len(spec)) return spec + (None,) * (ndim - len(spec))
def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
"""
A wrapper function to jax.lax.with_sharding_constraint to
support the case that Mesh is empty.
"""
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh.empty:
return x
return jax.lax.with_sharding_constraint(x, pspec)
def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str): def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str):
""" """
A wrapper function to invoke lax.p* operations, like psum. A wrapper function to invoke lax.p* operations, like psum.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment