Unverified Commit 71e51eae authored by Ming-Xu Huang, committed by GitHub

[JAX] Migrating from Xmap to Custom Partitioning for All Custom Calls (#472)



* Refactor sharding.py in preparation for the custom_partitioning migration
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating both FWD and BWD of LayerNorm/RMSNorm from xmap to custom_partitioning.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating both FWD and BWD of all kinds of softmax from xmap to custom_partitioning.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix the wrong order of parameters to LN/RMSN bwd in ln_mlp_fp8.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* WAR to LN/RMSN_fp8 before migrating to CP.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix the wrong parameter order in the bwd of LN/RMSN_fp8.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Modify following review feedback
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Force the hidden dim in Norm ops to no sharding and add a warning msg.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Reuse fwd_rule in VJP functions
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating both FWD and BWD of self-fused-attn from xmap to custom_partitioning.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating both FWD and BWD of cross-fused-attn from xmap to custom_partitioning.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Add gelu and dgelu.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Reuse fwd_rule in VJP functions for attentions
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Apply native FP8 dtypes to fp8.py
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating cast_and_transpose from xmap to custom_partitioning
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrating transpose from xmap to custom_partitioning
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Apply XLA pattern match to perform FP8 GEMM.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrate layernorm_fp8 to custom_partitioning.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Unify code style
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Extend support for Transpose with FP8
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Implementing layernorm_fp8_dot based on migrated custom calls.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Rename variables and publish NVTE_FP8_COLLECTION_NAME
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Replace Q/DQ custom calls with native XLA implementations
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Migrate gelu_fp to custom_partitioning.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Minor fix
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Support custom calls with multiple dims
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Support general dot indices in _fp8_dot_impl
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Implementing layernorm_geglu_fp8_mlp
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Remove GEMM custom calls
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Remove xmap-related code
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix typo and add a query function to FP8MetaPackage
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix some bugs in custom calls
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix CT bugs
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Update UTs/examples to adapt to the API changes.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Unify kernel initialization in MLP.
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Modify per code review feedback
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Update README and add a deprecation warning to *ShardingType
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Canonicalize the dtype
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding assertion for unsupported batch dims.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding doc/examples to _multidim_transpose
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Set FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Apply dtype-based rtol/atol to UTs
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Deprecate QKV_INTERLEAVED enum
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Skip test_distributed_custom_ops.py
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the wrong sharding of bias in SelfAttn
Signed-off-by: Ming Huang <mingh@nvidia.com>

* WAR to fix the wrong cu_seqlen of MHA when DP/FSDP enabled
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding distributed ops unit-tests
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding license to test_distributed_*
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Modify following review feedback
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Use total bytes involved in collective ops as criteria.
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
Co-authored-by: Donglin Yang <dongliny@nvidia.com>
parent 7976bd00
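With custom_partitioning, each custom call declares which shardings it can accept instead of being wrapped in an xmap. The "force the hidden dim in Norm ops to no sharding" commit describes one such rule: LayerNorm/RMSNorm reduce over the hidden (last) dimension, so that dim must stay replicated. A minimal plain-Python sketch of that rule (names are ours, not the PR's code):

```python
# Hypothetical sketch of the norm-op sharding rule: given a proposed
# partition spec (one mesh-axis name or None per tensor dim), return the
# spec the op can actually support. The hidden (last) dim must stay
# unsharded because mean/variance are reduced over it on every device.
import warnings

def norm_partition_spec(spec):
    """Force the last dim of a norm op's partition spec to None."""
    spec = list(spec)
    if spec[-1] is not None:
        warnings.warn("Hidden dim of a Norm op cannot be sharded; "
                      "forcing it to no sharding (replicated).")
        spec[-1] = None
    return tuple(spec)

print(norm_partition_spec(("data", None, "model")))  # -> ('data', None, None)
```

The real implementation registers logic like this as the `infer_sharding_from_operands`/`partition` callbacks of `jax.experimental.custom_partitioning`; the sketch only shows the shape of the decision.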
@@ -126,8 +126,6 @@ Flax
     for _ in range(10):
         loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
-        # Update FP8 metas
-        other_variables = te.update_fp8_metas(other_grads)

 .. overview-end-marker-do-not-remove
......
@@ -58,20 +58,18 @@ class Net(nn.Module):
         x = te_flax.DenseGeneral(features=256,
                                  kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
                                  bias_axes=(NAMED_TP_AXIS,),
-                                 sharding_type=te.ShardingType.DP_TP_COL,
                                  dtype=jnp.bfloat16)(x)
         x = te_flax.DenseGeneral(features=256,
                                  kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
                                  bias_axes=(NAMED_BROADCAST_AXIS,),
-                                 sharding_type=te.ShardingType.DP_TP_ROW,
                                  dtype=jnp.bfloat16)(x)
         x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
         return x


-def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
+def train_step(state, inputs, masks, labels, var_collect, rngs):
     """Computes gradients, loss and accuracy for a single batch."""

     def loss_fn(var_collect, disable_dropout=False):
@@ -87,13 +85,11 @@ def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
     var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
     state = state.apply_gradients(grads=grads)
-    if use_fp8:
-        var_collect = te.update_fp8_metas(var_collect)
     return state, loss, accuracy, var_collect


-def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_fn):
+def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
     """Train for a single epoch."""
     train_ds_size = len(train_ds['sentence'])
     steps_per_epoch = train_ds_size // batch_size
@@ -108,7 +104,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_f
         batch_masks = train_ds['mask'][perm, ...]
         batch_labels = train_ds['label'][perm, ...]
         state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks,
-                                                      batch_labels, var_collect, rngs, use_fp8)
+                                                      batch_labels, var_collect, rngs)
         epoch_loss.append(loss)
         epoch_accuracy.append(accuracy)
@@ -206,9 +202,8 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
     rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
-    assert "Float8" in str(
-        jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect,
-                                                     rngs, True))
+    assert "fp8_" in str(
+        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))


 def get_params_pspec(sharding_rules, abs_var_collect):
@@ -269,7 +264,8 @@ def train_and_evaluate(args):
     label_shape = [args.batch_size]

     with te.fp8_autocast(args.use_fp8,
-                         sharding_resource=te.ShardingResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS)):
+                         mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None,
+                                                       None)):
         encoder = Net(num_embed)
         inputs = jnp.zeros(input_shape, dtype=jnp.int32)
         masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
@@ -297,7 +293,7 @@ def train_and_evaluate(args):
     in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
     out_shardings = (state_pspec, None, None, None)
-    pjit_train_step = pjit(train_step, in_shardings, out_shardings, static_argnums=(6,))
+    pjit_train_step = pjit(train_step, in_shardings, out_shardings)

     in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None)
     out_shardings = (None, None)
@@ -310,7 +306,7 @@ def train_and_evaluate(args):
     if args.dry_run:
         labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
         rngs = {DROPOUT_KEY: dropout_rng}
-        pjit_train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8)
+        pjit_train_step(state, inputs, masks, labels, var_collect, rngs)
         print("PASSED")
         return None
@@ -320,8 +316,7 @@ def train_and_evaluate(args):
         rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
         state, train_loss, train_accuracy, var_collect = train_epoch(
-            state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8,
-            pjit_train_step)
+            state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step)
         test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size,
                                               var_collect, pjit_eval_step)
......
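A recurring change in these diffs is the removal of `te.update_fp8_metas` (and the `use_fp8` flag) from user training loops: FP8 scaling metadata now rides along in the gradient collection and is overwritten automatically. For context, a plain-Python sketch of the delayed-scaling update that the manual call used to perform; the helper name, constant, and `margin` handling are ours, not Transformer Engine's API:

```python
# Illustrative delayed-scaling update (not TE's code): recompute the FP8
# quantization scale from a rolling history of absolute maxima, so the
# largest recent value maps near the top of the FP8 range.
E4M3_MAX = 448.0  # largest representable magnitude in FP8 E4M3

def update_scale(amax_history, margin=0):
    """Pick a scale so that `amax` quantizes to ~E4M3_MAX / 2**margin."""
    amax = max(amax_history)
    if amax <= 0.0:
        return 1.0  # nothing observed yet; leave the scale neutral
    return (E4M3_MAX / amax) / (2.0 ** margin)

print(update_scale([0.5, 2.0, 1.0]))  # -> 224.0 (so amax 2.0 maps to 448)
```

After this PR, an update of this kind happens inside the framework (via `WeightHParamsCollection.OVERWRITE_WITH_GRADIENT` in the Praxis modules), which is why the example loops simply drop the call.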
@@ -52,17 +52,15 @@ class Net(nn.Module):
         x = x.reshape(x.shape[0], -1)

-        x = te_flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP,
-                                 dtype=jnp.bfloat16)(x)
+        x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)

-        x = te_flax.DenseGeneral(features=256, sharding_type=te.ShardingType.DP,
-                                 dtype=jnp.bfloat16)(x)
+        x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)

         x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
         return x


-def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
+def train_step(state, inputs, masks, labels, var_collect, rngs):
     """Computes gradients, loss and accuracy for a single batch."""

     def loss_fn(var_collect, disable_dropout=False):
@@ -78,13 +76,11 @@ def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
     var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
     state = state.apply_gradients(grads=grads)
-    if use_fp8:
-        var_collect = te.update_fp8_metas(var_collect)
     return state, loss, accuracy, var_collect


-def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_fn):
+def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
     """Train for a single epoch."""
     train_ds_size = len(train_ds['sentence'])
     steps_per_epoch = train_ds_size // batch_size
@@ -99,7 +95,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_f
         batch_masks = train_ds['mask'][perm, ...]
         batch_labels = train_ds['label'][perm, ...]
         state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks,
-                                                      batch_labels, var_collect, rngs, use_fp8)
+                                                      batch_labels, var_collect, rngs)
         epoch_loss.append(loss)
         epoch_accuracy.append(accuracy)
@@ -197,9 +193,8 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
     rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
-    assert "Float8" in str(
-        jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect,
-                                                     rngs, True))
+    assert "fp8_" in str(
+        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))


 def get_params_pspec(sharding_rules, abs_var_collect):
@@ -252,7 +247,8 @@ def train_and_evaluate(args):
     mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
     label_shape = [args.batch_size]

-    with te.fp8_autocast(args.use_fp8, sharding_resource=te.ShardingResource(DEVICE_DP_AXIS)):
+    with te.fp8_autocast(args.use_fp8,
+                         mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None)):
         encoder = Net(num_embed)
         inputs = jnp.zeros(input_shape, dtype=jnp.int32)
         masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
@@ -279,7 +275,7 @@ def train_and_evaluate(args):
     in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
     out_shardings = (state_pspec, None, None, None)
-    pjit_train_step = pjit(train_step, in_shardings, out_shardings, static_argnums=(6,))
+    pjit_train_step = pjit(train_step, in_shardings, out_shardings)

     in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None)
     out_shardings = (None, None)
@@ -292,7 +288,7 @@ def train_and_evaluate(args):
     if args.dry_run:
         labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
         rngs = {DROPOUT_KEY: dropout_rng}
-        pjit_train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8)
+        pjit_train_step(state, inputs, masks, labels, var_collect, rngs)
         print("PASSED")
         return None
@@ -302,8 +298,7 @@ def train_and_evaluate(args):
         rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
         state, train_loss, train_accuracy, var_collect = train_epoch(
-            state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8,
-            pjit_train_step)
+            state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step)
         test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size,
                                               var_collect, pjit_eval_step)
......
@@ -61,13 +61,11 @@ class Net(nn.Module):
         x = te_flax.DenseGeneral(features=256,
                                  kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
                                  bias_axes=(NAMED_TP_AXIS,),
-                                 sharding_type=te.ShardingType.DP_TP_COL,
                                  dtype=jnp.bfloat16)(x)
         x = te_flax.DenseGeneral(features=256,
                                  kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
                                  bias_axes=(NAMED_BROADCAST_AXIS,),
-                                 sharding_type=te.ShardingType.DP_TP_ROW,
                                  dtype=jnp.bfloat16)(x)
         x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
@@ -106,7 +104,7 @@ def shard_array_wrapper(dataset, batch_size, mesh, pspec, enable_partition=False
     return global_input_shape, named_sharding, inputs


-def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
+def train_step(state, inputs, masks, labels, var_collect, rngs):
     """Computes gradients, loss and accuracy for a single batch."""

     def loss_fn(var_collect, disable_dropout=False):
@@ -122,14 +120,12 @@ def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
     var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
     state = state.apply_gradients(grads=grads)
-    if use_fp8:
-        var_collect = te.update_fp8_metas(var_collect)
     return state, loss, accuracy, var_collect


-def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_fn, mesh,
-                inputs_pspec, masks_pspec, labels_pspec):
+def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn, mesh, inputs_pspec,
+                masks_pspec, labels_pspec):
     """Train for a single epoch."""
     total_batch_size = len(train_ds['sentence'])
@@ -164,7 +160,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_f
                                             label_named_sharding, [batch_label])
         state, loss, accuracy, var_collect = train_fn(state, shard_input, shard_mask, shard_label,
-                                                      var_collect, rngs, use_fp8)
+                                                      var_collect, rngs)
         epoch_loss.append(loss)
         epoch_accuracy.append(accuracy)
@@ -280,9 +276,8 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
     rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
-    assert "Float8" in str(
-        jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect,
-                                                     rngs, True))
+    assert "fp8_" in str(
+        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))


 def get_params_pspec(sharding_rules, abs_var_collect):
@@ -350,7 +345,8 @@ def train_and_evaluate(args):
     label_shape = [args.batch_size]

     with te.fp8_autocast(args.use_fp8,
-                         sharding_resource=te.ShardingResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS)):
+                         mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None,
+                                                       None)):
         encoder = Net(num_embed)
         inputs = jnp.zeros(input_shape, dtype=jnp.int32)
         masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
@@ -378,7 +374,7 @@ def train_and_evaluate(args):
     in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
     out_shardings = (state_pspec, None, None, None)
-    pjit_train_step = pjit(train_step, in_shardings, out_shardings, static_argnums=(6,))
+    pjit_train_step = pjit(train_step, in_shardings, out_shardings)

     in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None)
     out_shardings = (None, None)
@@ -391,7 +387,7 @@ def train_and_evaluate(args):
     if args.dry_run:
         labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
         rngs = {DROPOUT_KEY: dropout_rng}
-        pjit_train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8)
+        pjit_train_step(state, inputs, masks, labels, var_collect, rngs)
         print("PASSED")
     else:
         for epoch in range(1, args.epochs + 1):
@@ -400,8 +396,8 @@ def train_and_evaluate(args):
             rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
             state, train_loss, train_accuracy, var_collect = train_epoch(
-                state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8,
-                pjit_train_step, shard_mesh, inputs_pspec, masks_pspec, labels_pspec)
+                state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step,
+                shard_mesh, inputs_pspec, masks_pspec, labels_pspec)
             test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size,
                                                   var_collect, pjit_eval_step, shard_mesh,
......
@@ -56,7 +56,7 @@ class Net(nn.Module):
 @partial(jax.jit, static_argnums=6)
-def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
+def train_step(state, inputs, masks, labels, var_collect, rngs):
     """Computes gradients, loss and accuracy for a single batch."""

     def loss_fn(var_collect, disable_dropout=False):
@@ -72,13 +72,11 @@ def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
     var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
     state = state.apply_gradients(grads=grads)
-    if use_fp8:
-        var_collect = te.update_fp8_metas(var_collect)
     return state, loss, accuracy, var_collect


-def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8):
+def train_epoch(state, train_ds, batch_size, rngs, var_collect):
     """Train for a single epoch."""
     train_ds_size = len(train_ds['sentence'])
     steps_per_epoch = train_ds_size // batch_size
@@ -93,7 +91,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8):
         batch_masks = train_ds['mask'][perm, ...]
         batch_labels = train_ds['label'][perm, ...]
         state, loss, accuracy, var_collect = train_step(state, batch_inputs, batch_masks,
-                                                        batch_labels, var_collect, rngs, use_fp8)
+                                                        batch_labels, var_collect, rngs)
         epoch_loss.append(loss)
         epoch_accuracy.append(accuracy)
@@ -192,9 +190,8 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
     rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
-    assert "Float8" in str(
-        jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect,
-                                                     rngs, True))
+    assert "fp8_" in str(
+        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))


 def train_and_evaluate(args):
@@ -228,7 +225,7 @@ def train_and_evaluate(args):
     if args.dry_run:
         labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
         rngs = {DROPOUT_KEY: dropout_rng}
-        train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8)
+        train_step(state, inputs, masks, labels, var_collect, rngs)
         print("PASSED")
         return None
@@ -238,7 +235,7 @@ def train_and_evaluate(args):
         rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
         state, train_loss, train_accuracy, var_collect = train_epoch(
-            state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8)
+            state, train_ds, args.batch_size, rngs, var_collect)
         test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)
......
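The commit list also mentions documenting `_multidim_transpose` and supporting custom calls with multiple dims. The PR's own helper is not shown in these diffs, so as a guess at the idea: FP8 GEMM wants a 2D-style transpose of the operand, and for an N-D tensor that amounts to rotating the trailing contracting axes to the front. A NumPy sketch (function name and default are ours):

```python
# Hypothetical multi-dim transpose for FP8 GEMM: rotate the trailing
# `num_last_dims` axes to the front, so e.g. a (B1, B2, M, N) activation
# becomes (N, B1, B2, M) for the transposed GEMM operand.
import numpy as np

def multidim_transpose(x, num_last_dims=1):
    ndim = x.ndim
    perm = (tuple(range(ndim - num_last_dims, ndim))   # trailing axes first
            + tuple(range(ndim - num_last_dims)))      # then the rest, in order
    return np.transpose(x, perm)

a = np.zeros((2, 3, 4, 5))
print(multidim_transpose(a, 1).shape)  # -> (5, 2, 3, 4)
```

For `num_last_dims=1` this is equivalent to `np.moveaxis(x, -1, 0)`; the general form handles operands whose contraction spans several axes.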
@@ -75,15 +75,13 @@ def apply_model(state, images, labels, var_collect, rngs=None):
 @partial(jax.jit, static_argnums=2)
-def update_model(state, grads, use_fp8):
+def update_model(state, grads):
     """Update model params and FP8 meta."""
     state = state.apply_gradients(grads=grads[PARAMS_KEY])
-    if use_fp8:
-        grads = te.update_fp8_metas(grads)
     return state, grads

-def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8):
+def train_epoch(state, train_ds, batch_size, rngs, var_collect):
     """Train for a single epoch."""
     train_ds_size = len(train_ds['image'])
     steps_per_epoch = train_ds_size // batch_size
@@ -97,7 +95,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8):
         batch_images = train_ds['image'][perm, ...]
         batch_labels = train_ds['label'][perm, ...]
         grads, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect, rngs)
-        state, var_collect = update_model(state, grads, use_fp8)
+        state, var_collect = update_model(state, grads)
         epoch_loss.append(loss)
         epoch_accuracy.append(accuracy)
@@ -150,7 +148,7 @@ def get_datasets():
 def check_fp8(state, var_collect, input_shape, label_shape):
     "Check if model includes FP8."
-    assert "Float8" in str(
+    assert "f8_" in str(
         jax.make_jaxpr(apply_model)(state, jnp.empty(input_shape, dtype=jnp.bfloat16),
                                     jnp.empty(label_shape, dtype=jnp.bfloat16), var_collect))
@@ -195,7 +193,7 @@ def train_and_evaluate(args):
         rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
         state, train_loss, train_accuracy, var_collect = train_epoch(
-            state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8)
+            state, train_ds, args.batch_size, rngs, var_collect)
         test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)
         print(f"Epoch: {epoch:>2} "
...
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import jax
from itertools import product
from transformer_engine.jax.sharding import ShardingType
from transformer_engine.jax.softmax import SoftmaxType
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType
class DistributedConfigsHelper(object):
def __init__(self, num_gpus=len(jax.devices())):
super().__init__()
self.device_count = min(num_gpus, 8)
if self.device_count < 2:
self.layernorm_refs = []
self.softmax_types = []
self.softmax_refs = []
self.self_attn_bias_types = []
self.self_attn_mask_types = []
self.self_attn_refs = []
self.cross_attn_mask_types = []
self.cross_attn_refs = []
return
mesh_configs = [
((self.device_count, 1), ("dp", None), ShardingType.DP),
((self.device_count, 1), ("tp", None), ShardingType.TP_COL),
((self.device_count, 1), ("tp", None), ShardingType.TP_ROW)
]
if self.device_count >= 4:
mesh_configs += [
((self.device_count//2, 2), ("dp", "tp"), ShardingType.DP_TP_COL),
((self.device_count//2, 2), ("dp", "tp"), ShardingType.DP_TP_ROW),
]
if self.device_count >= 6:
mesh_configs += [
((2, self.device_count//2), ("dp", "tp"), ShardingType.DP_TP_COL),
((2, self.device_count//2), ("dp", "tp"), ShardingType.DP_TP_ROW),
]
layernorm_collectives = {
ShardingType.DP : {'all-reduce': 2, 'other': 0},
ShardingType.TP_COL : {'all-reduce': 0, 'other': 0},
ShardingType.DP_TP_COL : {'all-reduce': 2, 'other': 0}
}
self.layernorm_refs = [
mesh_config + (layernorm_collectives[mesh_config[2]], ) \
for mesh_config in mesh_configs \
if mesh_config[2] not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW)
]
self.softmax_types = [
SoftmaxType.SCALED,
SoftmaxType.SCALED_MASKED,
SoftmaxType.SCALED_UPPER_TRIANG_MASKED
]
softmax_collectives = {
ShardingType.DP : {'all-reduce': 1, 'other': 0},
ShardingType.TP_COL : {'all-reduce': 1, 'other': 0},
ShardingType.TP_ROW : {'all-reduce': 1, 'other': 0},
ShardingType.DP_TP_COL : {'all-reduce': 1, 'other': 0},
ShardingType.DP_TP_ROW : {'all-reduce': 1, 'other': 0}
}
self.softmax_refs = [
mesh_config + (softmax_collectives[mesh_config[2]], ) for mesh_config in mesh_configs
]
self.self_attn_bias_types = [
AttnBiasType.NO_BIAS,
AttnBiasType.PRE_SCALE_BIAS,
AttnBiasType.POST_SCALE_BIAS
]
self.self_attn_mask_types = [
AttnMaskType.CAUSAL_MASK,
AttnMaskType.PADDING_MASK,
AttnMaskType.NO_MASK
]
self_attn_collectives = {
ShardingType.DP : {
AttnBiasType.NO_BIAS : {'all-reduce': 1, 'other': 0},
AttnBiasType.PRE_SCALE_BIAS : {'all-reduce': 2, 'other': 0},
AttnBiasType.POST_SCALE_BIAS : {'all-reduce': 2, 'other': 0},
},
ShardingType.TP_COL : {
AttnBiasType.NO_BIAS : {'all-reduce': 1, 'other': 0},
AttnBiasType.PRE_SCALE_BIAS : {'all-reduce': 1, 'other': 0},
AttnBiasType.POST_SCALE_BIAS : {'all-reduce': 1, 'other': 0}
},
ShardingType.DP_TP_COL : {
AttnBiasType.NO_BIAS : {'all-reduce': 1, 'other': 0},
AttnBiasType.PRE_SCALE_BIAS : {'all-reduce': 2, 'other': 0},
AttnBiasType.POST_SCALE_BIAS : {'all-reduce': 2, 'other': 0}
},
}
self.self_attn_refs = [
mesh_config + (bias_type, self_attn_collectives[mesh_config[2]][bias_type]) \
for mesh_config, bias_type in product(mesh_configs, self.self_attn_bias_types) \
if mesh_config[2] not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW)
]
self.cross_attn_mask_types = [
AttnMaskType.PADDING_MASK,
AttnMaskType.NO_MASK
]
self.cross_attn_refs = [
mesh_config + ({'all-reduce': 1, 'other': 0}, ) \
for mesh_config in mesh_configs \
if mesh_config[2] not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW)
]
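The reference lists above pair each mesh layout with its expected collective counts and drop the row-sharded variants, since the norm ops keep the hidden dimension unsharded. A standalone sketch of that filtering, using a stub `ShardingType` in place of the real enum from `transformer_engine.jax.sharding`:

```python
from enum import Enum

# Stub of transformer_engine.jax.sharding.ShardingType, for illustration only.
class ShardingType(Enum):
    DP = "dp"
    TP_COL = "tp_col"
    TP_ROW = "tp_row"
    DP_TP_COL = "dp_tp_col"
    DP_TP_ROW = "dp_tp_row"

device_count = 8
mesh_configs = [
    ((device_count, 1), ("dp", None), ShardingType.DP),
    ((device_count, 1), ("tp", None), ShardingType.TP_COL),
    ((device_count, 1), ("tp", None), ShardingType.TP_ROW),
    ((device_count // 2, 2), ("dp", "tp"), ShardingType.DP_TP_COL),
    ((device_count // 2, 2), ("dp", "tp"), ShardingType.DP_TP_RW
        if False else ShardingType.DP_TP_ROW),
]

# Expected collective counts per sharding type; row-sharded norms are excluded
# because the hidden dimension must stay whole.
layernorm_collectives = {
    ShardingType.DP: {'all-reduce': 2, 'other': 0},
    ShardingType.TP_COL: {'all-reduce': 0, 'other': 0},
    ShardingType.DP_TP_COL: {'all-reduce': 2, 'other': 0},
}
layernorm_refs = [
    cfg + (layernorm_collectives[cfg[2]],)
    for cfg in mesh_configs
    if cfg[2] not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW)
]
print(len(layernorm_refs))  # 3
```

Each entry is then unpacked by `pytest.mark.parametrize` as `(mesh_shape, mesh_names, sharding_type, collective_ref)`.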
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import pytest
import numpy as np
from dataclasses import dataclass
from typing import Tuple
from enum import Enum
from functools import partial
import jax
import jax.numpy as jnp
from jax import random
from jax.experimental.pjit import pjit, _UNSPECIFIED
from jax.sharding import PartitionSpec
import flax
from transformer_engine.jax.sharding import ShardingType
try:
# try importing the new custom partitioning implementation
from transformer_engine.jax.sharding import MeshResource
except ImportError:
# must be using an older TE/JAX version so fall back on the xmap sharding implementation
MeshResource = None
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.softmax import SoftmaxType, softmax
from transformer_engine.jax.fused_attn import \
AttnBiasType, AttnMaskType, is_fused_attn_kernel_available, self_fused_attn, cross_fused_attn
class FusedAttnBackend(Enum):
Max512 = "0"
Arbitrary = "1"
@pytest.fixture(name="backend", params=[FusedAttnBackend.Max512, FusedAttnBackend.Arbitrary])
def fixture_backend(request):
backend = request.param
os.environ["NVTE_FUSED_ATTN_BACKEND"] = backend.value
yield backend
os.environ["NVTE_FUSED_ATTN_BACKEND"] = ""
@dataclass
class DistributedOpsHelper:
qkv_shape: Tuple[int,int,int,int] = (32, 128, 16, 64)
pad_ratio: float = 0.3
dropout_prob: float = 0.1
dtype: type = jnp.float16
@staticmethod
def use_custom_partitioning():
return (MeshResource is not None)
@staticmethod
def get_sharding_spec(mesh_names, sharding_type):
P = PartitionSpec
if sharding_type is ShardingType.DP:
return P(mesh_names[0], None), P(None), P(None)
elif sharding_type is ShardingType.DP_TP_COL:
return P(mesh_names[0], mesh_names[1]), P(None), P(None)
else:
raise NotImplementedError
@staticmethod
def get_sharding_resource(mesh_names, sharding_type):
dp_r = None
tp_r = None
if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
dp_r = mesh_names[0]
if sharding_type in (ShardingType.TP_COL, ShardingType.TP_ROW):
tp_r = mesh_names[0]
if sharding_type in (ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
tp_r = mesh_names[1]
return MeshResource(dp_r, tp_r)
@staticmethod
def make_mask(q_tokens, kv_tokens, mask_type, dtype=jnp.uint8):
if mask_type == AttnMaskType.CAUSAL_MASK:
causal = flax.linen.make_causal_mask(q_tokens, dtype=dtype)
padding = flax.linen.make_attention_mask(q_tokens > 0, kv_tokens > 0, dtype=dtype)
return flax.linen.combine_masks(causal, padding)
else:
return flax.linen.make_attention_mask(q_tokens > 0, kv_tokens > 0, dtype=dtype)
@staticmethod
def count_collectives(hlo):
tmp = hlo.splitlines()
symb = "-start"
result = {
"all-reduce" : 0,
"other" : 0
}
for line in tmp:
txt = line.split()
if len(txt) > 0 and symb in txt[0]:
if "all-reduce" in txt[0]:
result["all-reduce"] += 1
else:
result["other"] += 1
return result
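`count_collectives` keys off the first whitespace token of each HLO line, treating any async `-start` op that is not an all-reduce as "other". A standalone sketch of the same logic on made-up HLO-like lines (real XLA dumps are far noisier):

```python
def count_collectives(hlo):
    """Count async collective launches by scanning for '-start' in the first token."""
    result = {"all-reduce": 0, "other": 0}
    for line in hlo.splitlines():
        tokens = line.split()
        if tokens and "-start" in tokens[0]:
            if "all-reduce" in tokens[0]:
                result["all-reduce"] += 1
            else:
                result["other"] += 1
    return result

# Hypothetical HLO-flavored text, not real compiler output.
sample_hlo = """
all-reduce-start.1 = f32[1024]{0} all-reduce(f32[1024]{0} x)
add.2 = f32[1024]{0} add(y, z)
all-gather-start.3 = f32[2048]{0} all-gather(f32[1024]{0} x)
"""
print(count_collectives(sample_hlo))  # {'all-reduce': 1, 'other': 1}
```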
def get_tolerance(self, ref_val, relaxation=2./3., dtype=None):
if dtype is None:
dtype = self.dtype
# slightly relax the machine epsilon for minimum tolerance
eps_relaxed = jax.lax.pow(jnp.finfo(dtype).eps, dtype(relaxation))
# calculate the "Unit of Least Precision" -- i.e. distance to the next representable number
spacing_high = jnp.nextafter(dtype(ref_val), jnp.finfo(dtype).max) - dtype(ref_val)
spacing_low = dtype(ref_val) - jnp.nextafter(dtype(ref_val), jnp.finfo(dtype).min)
ulp = jax.lax.max(spacing_low, spacing_high)
return jax.lax.max(eps_relaxed, ulp)
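`get_tolerance` takes the larger of a relaxed machine epsilon and the ULP at the reference value, so the tolerance never drops below the spacing of representable numbers around the reference. A NumPy sketch mirroring (not reusing) the JAX version above:

```python
import numpy as np

def get_tolerance(ref_val, relaxation=2.0 / 3.0, dtype=np.float16):
    # Relax machine epsilon: eps**(2/3) is looser than eps itself.
    eps_relaxed = np.finfo(dtype).eps ** relaxation
    ref = dtype(ref_val)
    # ULP: distance to the nearest representable neighbours of ref.
    spacing_high = np.nextafter(ref, np.finfo(dtype).max) - ref
    spacing_low = ref - np.nextafter(ref, np.finfo(dtype).min)
    ulp = max(spacing_low, spacing_high)
    return max(float(eps_relaxed), float(ulp))

# Near 1.0 in float16 the relaxed epsilon (~9.8e-3) dominates the ULP (~9.8e-4).
print(get_tolerance(1.0) > float(np.finfo(np.float16).eps))
```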
def compare_ops(self, custom_func, ref_func, ref_count,
*args, grad_args=None, dtype=None,
in_shardings=_UNSPECIFIED, out_shardings=_UNSPECIFIED,
**kwargs):
if dtype is None:
dtype = self.dtype
if isinstance(custom_func, partial):
func_name = custom_func.func.__name__
else:
func_name = custom_func.__name__
func_name = func_name.removeprefix('custom_')
if grad_args is None:
grad_args = tuple(range(len(args)))
custom_gradded = jax.value_and_grad(custom_func, argnums=grad_args)
test_fwd, test_grads = custom_gradded(*args, **kwargs)
custom_pjitter = pjit(custom_gradded,
in_shardings=in_shardings,
out_shardings=out_shardings)
custom_hlo = custom_pjitter.lower(*args, **kwargs).compile().as_text()
custom_count = self.count_collectives(custom_hlo)
if ref_count is not None:
assert custom_count==ref_count, \
f"`{func_name}`: Expected collective count is {ref_count}, but got {custom_count}."
else:
print(f"`{func_name}`: Output collective count is {custom_count}.")
ref_gradded = jax.value_and_grad(ref_func, argnums=grad_args)
ref_fwd, ref_grads = ref_gradded(*args, **kwargs)
fwd_tol = self.get_tolerance(ref_fwd, dtype=dtype)
assert jnp.allclose(test_fwd, ref_fwd, rtol=0.0, atol=fwd_tol), \
f"`{func_name}`: Output (fwd) error {jnp.max(jnp.abs(test_fwd - ref_fwd))}" + \
f" exceeds tolerance ({fwd_tol})."
if isinstance(grad_args, int):
# jax.value_and_grad returns a bare array (not a tuple) for integer argnums
ref_grads = (ref_grads, )
test_grads = (test_grads, )
failed_grads = {}
for i, grads in enumerate(zip(test_grads, ref_grads)):
test_grad, ref_grad = grads
if test_grad is None and ref_grad is None:
continue
bwd_tol = self.get_tolerance(ref_grad, dtype=dtype)
if not jnp.allclose(test_grad, ref_grad, rtol=0.0, atol=bwd_tol):
failed_grads[i] = jnp.max(jnp.abs(test_grad - ref_grad))
assert len(failed_grads) == 0, \
f"`{func_name}`: Gradient (bwd) max errors" + \
f" [{', '.join([f'Arg{k}={v}' for k,v in failed_grads.items()])}]" + \
f" exceed tolerance ({bwd_tol})."
def check_fused_attn_inputs(self, q_seq, kv_seq, head_dim, pad_ratio, dropout_probability,
attn_bias_type, attn_mask_type, backend, dtype=jnp.float16):
if (q_seq > 512 or kv_seq > 512 or backend == FusedAttnBackend.Arbitrary) \
and pad_ratio != 0:
pytest.skip(
"`fused_attention`: Arbitrary seqlen backend does not support padded input.")
if not is_fused_attn_kernel_available(
dtype, dtype, attn_bias_type, attn_mask_type,
dropout_probability, q_seq, kv_seq, head_dim):
pytest.skip(
"`fused_attention`: Unsupported inputs combination or device compute capability.")
def fused_attn_core(self, query, key, value, bias, mask, scale_factor,
attn_bias_type, attn_mask_type, dropout_rng, dropout_prob):
# Q*K matmul
query = jnp.squeeze(query)
key = jnp.squeeze(key)
value = jnp.squeeze(value)
attn_weights = jnp.einsum("...qhd,...khd->...hqk", query, key)
# scale and bias
if attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
attn_weights = scale_factor * (attn_weights + bias)
elif attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
attn_weights = scale_factor * attn_weights + bias
else:
attn_weights = scale_factor * attn_weights
# padding mask
if attn_mask_type != AttnMaskType.NO_MASK and mask is not None:
big_neg = jnp.finfo(query.dtype).min
attn_weights = jnp.where(mask, attn_weights, big_neg)
# softmax
attn_weights = jax.nn.softmax(attn_weights).astype(query.dtype)
# dropout
if dropout_prob == 1.0:
attn_weights = jnp.zeros_like(attn_weights)
elif dropout_prob > 0.0:
keep_prob = 1.0 - dropout_prob
keep = random.bernoulli(dropout_rng, p=keep_prob, shape=attn_weights.shape)
multiplier = keep.astype(query.dtype) / jnp.asarray(keep_prob, dtype=query.dtype)
attn_weights = attn_weights * multiplier
# QK*V matmul
result = jnp.einsum('...hqk,...khd->...qhd', attn_weights, value)
return jnp.mean(result)
@staticmethod
def custom_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, sharding_type):
result = layernorm(x, gamma, beta,
layernorm_type='layernorm',
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type,
dp_dim_index=0)
return jnp.mean(result)
def reference_layernorm(self, x, gamma, beta, zero_centered_gamma, epsilon):
x_ = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x_ - mean) * jax.lax.rsqrt(var + epsilon)
if zero_centered_gamma:
result = jnp.asarray(normed_input * (gamma + 1) + beta).astype(x.dtype)
else:
result = jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
return jnp.mean(result)
@staticmethod
def custom_rmsnorm(x, gamma, epsilon, sharding_type):
result = layernorm(x, gamma, None,
layernorm_type='rmsnorm',
zero_centered_gamma=False,
epsilon=epsilon,
sharding_type=sharding_type,
dp_dim_index=0)
return jnp.mean(result)
def reference_rmsnorm(self, x, gamma, epsilon):
x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * jax.lax.rsqrt(mean2 + epsilon), x.dtype)
result = y * gamma
return jnp.mean(result)
@staticmethod
def custom_softmax(x, mask, scale_factor, softmax_type, sharding_type):
result = softmax(x, mask,
scale_factor=scale_factor,
softmax_type=softmax_type,
sharding_type=sharding_type)
return jnp.mean(result)
def reference_softmax(self, x, mask, scale_factor, softmax_type):
attn_weights = scale_factor * x
if softmax_type != SoftmaxType.SCALED:
big_neg = jnp.finfo(x.dtype).min
attn_weights = jnp.where(mask, attn_weights, big_neg)
result = jax.nn.softmax(attn_weights).astype(x.dtype)
return jnp.mean(result)
@staticmethod
def custom_self_fused_attn(qkv, bias, mask, rng_key, dropout_prob,
attn_bias_type, attn_mask_type,
scaling_factor, sharding_type):
mask = (mask == 0) # invert mask
bias_ = None if attn_bias_type == AttnBiasType.NO_BIAS else bias
result = self_fused_attn(qkv, bias_, mask,
seed=rng_key,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=True,
sharding_type=sharding_type)
return jnp.mean(result)
def reference_self_fused_attn(self, qkv, bias, mask, rng_key, dropout_prob,
attn_bias_type, attn_mask_type,
scaling_factor):
# split interleaved QKV into separate matrices
query, key, value = jnp.split(qkv, [1, 2], axis=-3)
return self.fused_attn_core(
query, key, value, bias, mask, scaling_factor,
attn_bias_type, attn_mask_type,
rng_key, dropout_prob)
@staticmethod
def custom_cross_fused_attn(query, key_value, mask, rng_key, dropout_prob,
attn_mask_type, scaling_factor, sharding_type):
mask = (mask == 0) # invert mask
result = cross_fused_attn(query, key_value, mask,
seed=rng_key,
attn_bias_type=AttnBiasType.NO_BIAS,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=True,
sharding_type=sharding_type)
return jnp.mean(result)
def reference_cross_fused_attn(self, query, key_value, mask, rng_key, dropout_prob,
attn_mask_type, scaling_factor):
key, value = jnp.split(key_value, [1], axis=-3)
return self.fused_attn_core(
query, key, value, None, mask, scaling_factor,
AttnBiasType.NO_BIAS, attn_mask_type,
rng_key, dropout_prob)
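Both reference paths rely on `jnp.split` over the interleaved axis: self-attention packs Q, K, V at axis `-3` of a `(batch, seq, 3, heads, dim)` tensor, while cross-attention packs K, V at axis `-3` of a `(batch, seq, 2, heads, dim)` tensor. A quick NumPy check of those split indices (shapes are illustrative, matching the helper's `qkv_shape` convention):

```python
import numpy as np

# Self-attention: 3-way interleaved split at axis -3.
qkv = np.zeros((2, 4, 3, 2, 8), dtype=np.float32)
q, k, v = np.split(qkv, [1, 2], axis=-3)
print(q.shape)  # (2, 4, 1, 2, 8); the singleton dim is squeezed in fused_attn_core

# Cross-attention: 2-way split at axis -3.
kv = np.zeros((2, 4, 2, 2, 8), dtype=np.float32)
k2, v2 = np.split(kv, [1], axis=-3)
print(k2.shape)  # (2, 4, 1, 2, 8)
```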
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import operator
import re
from functools import reduce
import jax
from jax.experimental.pjit import pjit, _UNSPECIFIED
from transformer_engine.jax.sharding import MeshResource
from utils import assert_allclose, is_devices_enough
def generate_configs():
configs = []
if is_devices_enough(2):
configs.append([2, (2,), ('dp',), MeshResource(dp_resource='dp')])
configs.append([2, (2,), ('tp',), MeshResource(tp_resource='tp')])
if is_devices_enough(4):
TP_size = 2
DP_size = 2
configs.append(
[4, (DP_size, TP_size), ('dp', 'tp'),
MeshResource(dp_resource='dp', tp_resource='tp')])
return configs
COLL_AR_KEY = "all-reduce"
COLL_AG_KEY = "all-gather"
COLL_OTHER_KEY = "other"
def generate_collectives_count(allreduce, allgather, other):
return {COLL_AR_KEY: allreduce, COLL_AG_KEY: allgather, COLL_OTHER_KEY: other}
def assert_equal_collectives(target_hlo, coll_count_ref):
target_splitted_hlo = target_hlo.splitlines()
start_symb = "-start"
def count_bytes(hlo_text):
bytes_count = 0
def get_bytes_per_txt(t):
'''
The pattern of t would be like:
'f32[]',
'(f32[1024]{0}',
'f32[1024]{0})',
'f8E4M3FN[1024]{0}',
'i32[1024]{0}',
'bf16[1024,1024]{0}'
'''
match = re.search(r'(i|f)(\d+).*\[([0-9,]*)\]', t)
_, bits_of_type, shape = match.groups()
bytes_of_type = int(bits_of_type) // 8
if shape == '':
num_of_elements = 1
else:
num_of_elements = reduce(operator.mul, map(int, shape.split(',')))
return bytes_of_type * num_of_elements
# ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...]
if '(' in hlo_text[2]:
for txt in hlo_text[2:]:
bytes_count += get_bytes_per_txt(txt)
if ')' in txt:
break
else: # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...]
bytes_count = get_bytes_per_txt(hlo_text[2])
return bytes_count
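`count_bytes` leans on a single regex to pull the element width and shape out of an HLO type string, including FP8 types where only the leading bit-width digit is captured. A standalone sketch of that parse:

```python
import operator
import re
from functools import reduce

def get_bytes(hlo_type):
    """Parse strings like 'bf16[1024,1024]{0}' into a byte count."""
    match = re.search(r'(i|f)(\d+).*\[([0-9,]*)\]', hlo_type)
    _, bits, shape = match.groups()
    n_elems = 1 if shape == '' else reduce(operator.mul, map(int, shape.split(',')))
    return (int(bits) // 8) * n_elems

print(get_bytes('f32[]'))               # 4        (scalar float32)
print(get_bytes('bf16[1024,1024]{0}'))  # 2097152  (2 bytes x 1M elements)
print(get_bytes('f8E4M3FN[1024]{0}'))   # 1024     (regex grabs the leading '8' bits)
```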
def count_collectives(splitted_hlo):
result = generate_collectives_count(0, 0, 0)
for line in splitted_hlo:
txt = line.split()
if len(txt) > 0 and start_symb in txt[0]:
if COLL_AR_KEY in txt[0]:
result[COLL_AR_KEY] += count_bytes(txt)
elif COLL_AG_KEY in txt[0]:
result[COLL_AG_KEY] += count_bytes(txt)
else:
result[COLL_OTHER_KEY] += count_bytes(txt)
return result
target_result = count_collectives(target_splitted_hlo)
assert target_result == coll_count_ref, \
f"Expected collective count is {coll_count_ref}, but got {target_result}."
def compare_ops(target_func,
ref_func,
inputs,
coll_count_ref,
*,
grad_args=None,
metric_fwd_dtype=None,
metric_bwd_dtype=None,
in_shardings=_UNSPECIFIED,
out_shardings=_UNSPECIFIED,
**kwargs):
assert len(inputs) >= 1
if metric_fwd_dtype is None:
metric_fwd_dtype = inputs[0].dtype
if metric_bwd_dtype is None:
metric_bwd_dtype = inputs[0].dtype
if grad_args is None:
grad_args = tuple(range(len(inputs)))
target_grad_func = jax.value_and_grad(target_func, argnums=grad_args)
target_pjitter = pjit(target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
target_fwd, target_grads = target_pjitter(*inputs, **kwargs)
target_hlo = target_pjitter.lower(*inputs, **kwargs).compile().as_text()
ref_grad_func = jax.value_and_grad(ref_func, argnums=grad_args)
ref_pjitter = pjit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
ref_fwd, ref_grads = ref_pjitter(*inputs, **kwargs)
assert_allclose(target_fwd, ref_fwd, dtype=metric_fwd_dtype)
for i in range(len(target_grads)):
assert_allclose(target_grads[i], ref_grads[i], dtype=metric_bwd_dtype)
assert_equal_collectives(target_hlo, coll_count_ref)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import jax
import jax.numpy as jnp
from jax.core import ShapedArray
from transformer_engine_jax import DType
from transformer_engine.jax.cpp_extensions import te_dtype_to_jax_dtype
from transformer_engine.jax.cpp_extensions import GemmPrimitive
SHAPES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)]
NAMED_SHAPES = [{}, {
"data": 4
}, {
"data": 2
}, {
"model": 4
}, {
"model": 2
}, {
"data": 4,
"model": 2
}, {
"model": 4,
"data": 2
}]
DTYPE = [DType.kFloat32, DType.kFloat16, DType.kBFloat16]
TRANSPOSE = [True, False]
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
class TestGEMMShapeInfer:
@staticmethod
def _joint_named_shape(ns1, ns2):
output_named_shape = {**ns1}
need_assert = False
for key in ns2:
if key in output_named_shape and output_named_shape[key] != ns2[key]:
need_assert = True
else:
output_named_shape[key] = ns2[key]
return output_named_shape, need_assert
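The named-shape join succeeds only when the two operands agree on every shared axis; a conflict means `GemmPrimitive.abstract` is expected to raise. A standalone sketch of the same merge rule:

```python
def joint_named_shape(ns1, ns2):
    """Merge two named shapes; flag a conflict when a shared axis disagrees."""
    merged = {**ns1}
    conflict = False
    for key, size in ns2.items():
        if key in merged and merged[key] != size:
            conflict = True
        else:
            merged[key] = size
    return merged, conflict

print(joint_named_shape({"data": 4}, {"model": 2}))  # ({'data': 4, 'model': 2}, False)
print(joint_named_shape({"data": 4}, {"data": 2}))   # ({'data': 4}, True)
```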
@staticmethod
def _get_shapes(m, n, k, transa, transb):
# te_gemm only supports TN layouts in column-major order, so we reorder the
# a and b shapes to compute row-major matrices with the column-major algorithm.
a = (m, k) if transa else (k, m)
b = (k, n) if transb else (n, k)
out = (n, m)
return a, b, out
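The shape reordering leans on the identity (AB)ᵀ = BᵀAᵀ: a column-major GEMM writing into a row-major buffer effectively produces the transposed result, which is why `out` is declared as `(n, m)`. A quick NumPy check of that identity (illustrative only, not the te_gemm call path):

```python
import numpy as np

m, n, k = 4, 3, 5
rng = np.random.default_rng(0)
A = rng.normal(size=(m, k)).astype(np.float32)  # row-major (m, k)
B = rng.normal(size=(k, n)).astype(np.float32)  # row-major (k, n)
C = A @ B                                       # row-major (m, n) result

# A column-major GEMM filling the same buffer computes C.T with shape (n, m),
# matching the `out = (n, m)` bookkeeping above.
assert np.allclose(C.T, B.T @ A.T)
print(C.T.shape)  # (3, 4)
```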
@pytest.mark.parametrize('shapes', SHAPES)
@pytest.mark.parametrize('named_shape1', NAMED_SHAPES)
@pytest.mark.parametrize('named_shape2', NAMED_SHAPES)
@pytest.mark.parametrize('te_dtype', DTYPE)
@pytest.mark.parametrize('transa', TRANSPOSE)
@pytest.mark.parametrize('transb', TRANSPOSE)
def test_shape_infer(self, shapes, named_shape1, named_shape2, te_dtype, transa, transb):
a_shape, b_shape, out_shape = TestGEMMShapeInfer._get_shapes(*shapes, transa, transb)
dtype = te_dtype_to_jax_dtype(te_dtype)
mat_a = ShapedArray(a_shape, dtype, named_shape=named_shape1)
mat_b = ShapedArray(b_shape, dtype, named_shape=named_shape2)
scale_inv_a = ShapedArray((3, 1), jnp.float32)
scale_inv_b = ShapedArray((3, 1), jnp.float32)
ref_out_named_shape, need_assert = TestGEMMShapeInfer._joint_named_shape(
named_shape1, named_shape2)
ref_out = ShapedArray(out_shape, dtype, named_shape=ref_out_named_shape)
try:
test_out = GemmPrimitive.abstract(mat_a,
mat_b,
scale_inv_a,
scale_inv_b,
A_dtype=te_dtype,
B_dtype=te_dtype,
D_dtype=te_dtype,
transa=transa,
transb=transb,
use_split_accumulator=False)
assert not need_assert
assert ref_out == test_out
except AssertionError as ae:
assert need_assert, f"{ae.args}"
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import numpy as np
from functools import partial
import jax
import jax.numpy as jnp
from jax import random
from jax.sharding import NamedSharding
from utils import is_devices_enough
from distributed_configs_helper import *
from distributed_ops_helper import *
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType
configs = DistributedConfigsHelper() # default device count is len(jax.devices())
ops = DistributedOpsHelper() # default data type is jnp.float16
@pytest.mark.skipif(not is_devices_enough(configs.device_count),
reason='Insufficient number of GPUs, need at least 2.')
@pytest.mark.skipif(not ops.use_custom_partitioning(),
reason='TE/JAX version does not support sharding with ' + \
'jax.experimental.custom_partitioning.')
class TestCustomPartitioningOpsGenerator:
@pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref',
configs.layernorm_refs)
@pytest.mark.parametrize('zero_centered_gamma', [False, True])
def test_layernorm(self, mesh_shape, mesh_names, sharding_type, collective_ref,
zero_centered_gamma):
epsilon = 1e-6
custom_func = partial(ops.custom_layernorm,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type)
reference_func = partial(ops.reference_layernorm,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
batch_size, _, num_heads, head_dim = ops.qkv_shape
hidden_size = num_heads*head_dim
input_shape = (batch_size, hidden_size)
other_shape = (hidden_size, )
x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=ops.dtype)
gamma_ = jnp.ones(other_shape, dtype=ops.dtype)
beta_ = jnp.ones(other_shape, dtype=ops.dtype)
x_spec, gamma_spec, beta_spec = ops.get_sharding_spec(mesh_names, sharding_type)
devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, mesh_names)
with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
x_ = jax.device_put(x_, NamedSharding(mesh, x_spec))
gamma_ = jax.device_put(gamma_, NamedSharding(mesh, gamma_spec))
beta_ = jax.device_put(beta_, NamedSharding(mesh, beta_spec))
ops.compare_ops(
custom_func, reference_func, collective_ref,
x_, gamma_, beta_, grad_args=(0, 1, 2), dtype=ops.dtype,
in_shardings=[x_spec, gamma_spec, beta_spec],
out_shardings=(None, (x_spec, gamma_spec, beta_spec))
)
@pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref',
configs.layernorm_refs)
def test_rmsnorm(self, mesh_shape, mesh_names, sharding_type, collective_ref):
epsilon = 1e-6
custom_func = partial(ops.custom_rmsnorm, epsilon=epsilon, sharding_type=sharding_type)
reference_func = partial(ops.reference_rmsnorm, epsilon=epsilon)
batch_size, _, num_heads, head_dim = ops.qkv_shape
hidden_size = num_heads*head_dim
input_shape = (batch_size, hidden_size)
other_shape = (hidden_size, )
x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=ops.dtype)
gamma_ = jnp.ones(other_shape, dtype=ops.dtype)
x_spec, gamma_spec, _ = ops.get_sharding_spec(mesh_names, sharding_type)
devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, mesh_names)
with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
x_ = jax.device_put(x_, NamedSharding(mesh, x_spec))
gamma_ = jax.device_put(gamma_, NamedSharding(mesh, gamma_spec))
ops.compare_ops(
custom_func, reference_func, collective_ref,
x_, gamma_, grad_args=(0, 1), dtype=ops.dtype,
in_shardings=[x_spec, gamma_spec],
out_shardings=(None, (x_spec, gamma_spec))
)
@pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref',
configs.softmax_refs)
@pytest.mark.parametrize('softmax_type', configs.softmax_types)
def test_softmax(self, mesh_shape, mesh_names, sharding_type, collective_ref,
softmax_type):
batch_size, seq_len, num_heads, head_dim = ops.qkv_shape
scale_factor = 1./jnp.sqrt(head_dim)
custom_func = partial(ops.custom_softmax,
scale_factor=scale_factor,
softmax_type=softmax_type,
sharding_type=sharding_type)
reference_func = partial(ops.reference_softmax,
scale_factor=scale_factor,
softmax_type=softmax_type)
input_size = (batch_size, num_heads, seq_len, seq_len)
x_ = random.normal(random.PRNGKey(1124), input_size, dtype=ops.dtype)
pad_len = int(seq_len * ops.pad_ratio)
valid_len = seq_len - pad_len
tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=ops.dtype),
jnp.zeros((batch_size, pad_len), dtype=ops.dtype)),
axis=-1)
mask_ = ops.make_mask(tokens, tokens, AttnMaskType.PADDING_MASK)
x_spec, mask_spec, _ = ops.get_sharding_spec(mesh_names, sharding_type)
devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, mesh_names)
with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
x_ = jax.device_put(x_, NamedSharding(mesh, x_spec))
mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec))
ops.compare_ops(
custom_func, reference_func, collective_ref,
x_, mask_, grad_args=(0,), dtype=ops.dtype,
in_shardings=[x_spec, mask_spec],
out_shardings=(None, (x_spec,))
)
@pytest.mark.parametrize(
'mesh_shape, mesh_names, sharding_type, attn_bias_type, collective_ref',
configs.self_attn_refs)
@pytest.mark.parametrize('attn_mask_type', configs.self_attn_mask_types)
def test_self_fused_attn(self, mesh_shape, mesh_names, sharding_type, collective_ref,
attn_bias_type, attn_mask_type, backend):
batch_size, seq_len, num_heads, head_dim = ops.qkv_shape
ops.check_fused_attn_inputs(seq_len, seq_len, head_dim,
ops.pad_ratio, ops.dropout_prob,
attn_bias_type, attn_mask_type, backend)
dropout_rng = random.PRNGKey(91023051)
split_rng = random.split(dropout_rng, configs.device_count)
scale_factor = 1./jnp.sqrt(head_dim)
custom_func = partial(ops.custom_self_fused_attn,
rng_key=split_rng,
dropout_prob=ops.dropout_prob,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor,
sharding_type=sharding_type)
reference_func = partial(ops.reference_self_fused_attn,
rng_key=dropout_rng,
dropout_prob=ops.dropout_prob,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor)
key = random.PRNGKey(1124)
subkeys = random.split(key, 2)
qkv_shape = (batch_size, seq_len, 3, num_heads, head_dim)
qkv_ = random.normal(subkeys[0], qkv_shape, dtype=ops.dtype)
bias_shape = (1, num_heads, seq_len, seq_len)
bias_ = random.normal(subkeys[1], bias_shape, dtype=ops.dtype)
pad_len = int(seq_len * ops.pad_ratio)
valid_len = seq_len - pad_len
tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=ops.dtype),
jnp.zeros((batch_size, pad_len), dtype=ops.dtype)),
axis=-1)
mask_ = ops.make_mask(tokens, tokens, attn_mask_type)
qkv_spec, bias_spec, mask_spec = ops.get_sharding_spec(mesh_names, sharding_type)
devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, mesh_names)
with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
qkv_ = jax.device_put(qkv_, NamedSharding(mesh, qkv_spec))
bias_ = jax.device_put(bias_, NamedSharding(mesh, bias_spec))
mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec))
ops.compare_ops(
custom_func, reference_func, collective_ref,
qkv_, bias_, mask_, grad_args=(0, 1), dtype=ops.dtype,
in_shardings=[qkv_spec, bias_spec, mask_spec],
out_shardings=(None, (qkv_spec, bias_spec))
)
@pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref',
configs.cross_attn_refs)
@pytest.mark.parametrize('attn_mask_type', configs.cross_attn_mask_types)
def test_cross_fused_attn(self, mesh_shape, mesh_names, sharding_type, collective_ref,
attn_mask_type, backend):
batch_size, seq_len, num_heads, head_dim = ops.qkv_shape
ops.check_fused_attn_inputs(seq_len, seq_len, head_dim,
ops.pad_ratio, ops.dropout_prob,
AttnBiasType.NO_BIAS, attn_mask_type, backend)
dropout_rng = random.PRNGKey(91023051)
split_rng = random.split(dropout_rng, configs.device_count)
scale_factor = 1./jnp.sqrt(head_dim)
custom_func = partial(ops.custom_cross_fused_attn,
rng_key=split_rng,
dropout_prob=ops.dropout_prob,
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor,
sharding_type=sharding_type)
reference_func = partial(ops.reference_cross_fused_attn,
rng_key=dropout_rng,
dropout_prob=ops.dropout_prob,
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor)
key = random.PRNGKey(1124)
subkeys = random.split(key, 2)
q_shape = (batch_size, seq_len, num_heads, head_dim)
q_ = random.normal(subkeys[0], q_shape, dtype=ops.dtype)
kv_shape = (batch_size, seq_len, 2, num_heads, head_dim)
kv_ = random.normal(subkeys[1], kv_shape, dtype=ops.dtype)
pad_len = int(seq_len * ops.pad_ratio)
valid_len = seq_len - pad_len
tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=ops.dtype),
jnp.zeros((batch_size, pad_len), dtype=ops.dtype)),
axis=-1)
mask_ = ops.make_mask(tokens, tokens, attn_mask_type)
q_spec, kv_spec, mask_spec = ops.get_sharding_spec(mesh_names, sharding_type)
devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, mesh_names)
with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
q_ = jax.device_put(q_, NamedSharding(mesh, q_spec))
kv_= jax.device_put(kv_, NamedSharding(mesh, kv_spec))
mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec))
ops.compare_ops(
custom_func, reference_func, collective_ref,
q_, kv_, mask_, grad_args=(0, 1), dtype=ops.dtype,
in_shardings=[q_spec, kv_spec, mask_spec],
out_shardings=(None, (q_spec, kv_spec))
)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import jax
import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
DTYPES = [jnp.float16, jnp.bfloat16]
class TestDistributedSelfAttn:
def generate_collectives_count_ref(self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape,
dtype):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
_, seqlen, _, heads, _ = shape
is_dp_enabled = mesh_resource.dp_resource is not None
tp_size = 1
if mesh_resource.tp_resource is not None:
idx = mesh_axes.index(mesh_resource.tp_resource)
tp_size = mesh_shape[idx]
all_reduce_loss_bytes = 4 # 1 * FP32
bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
# for loss and dbias
return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)
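The byte accounting in `generate_collectives_count_ref` can be checked with plain arithmetic. The sketch below mirrors that formula; the concrete numbers (fp16 itemsize 2, `seqlen=512`, `heads=12`, 4-way TP, DP enabled) are hypothetical inputs chosen for illustration:

```python
def expected_allreduce_bytes(seqlen, heads, tp_size, itemsize,
                             with_bias, is_dp_enabled):
    """4 bytes for the FP32 loss allreduce, plus the dbias allreduce
    (bias heads are sharded over TP) when data parallelism is on."""
    all_reduce_loss_bytes = 4  # 1 * FP32
    bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * itemsize
    return all_reduce_loss_bytes + bias_bytes * int(is_dp_enabled)

total = expected_allreduce_bytes(seqlen=512, heads=12, tp_size=4, itemsize=2,
                                 with_bias=True, is_dp_enabled=True)
# dbias contributes 3 * 512 * 512 * 2 bytes; the loss adds 4 more.
```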
def generate_inputs(self, shape, mesh_resource, with_bias, attn_mask_type, dtype):
batch, seqlen, _, heads, _ = shape
qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
bias = random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype) \
if with_bias else None
mask = None
if attn_mask_type == AttnMaskType.PADDING_MASK:
mask = make_causal_mask(batch, seqlen)
elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_self_mask(batch, seqlen)
qkv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource,
None)
bias_pspec = PartitionSpec(None, mesh_resource.tp_resource, None, None) \
if with_bias else None
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \
if attn_mask_type != AttnMaskType.NO_MASK else None
return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
@pytest.mark.parametrize(
'attn_bias_type',
[AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dtype', DTYPES)
def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
attn_bias_type, attn_mask_type, dtype):
dropout_prob = 0.0
is_training = True
scaling_factor = 1.0
_, seqlen, _, _, hidden = data_shape
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
attn_mask_type, dropout_prob, seqlen, seqlen, hidden):
pytest.skip("No FusedAttn backend found")
def target_func(qkv, bias, mask):
return jnp.mean(
self_fused_attn(qkv,
bias,
mask,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training))
def ref_func(qkv, bias, mask):
query, key, value = jnp.split(qkv, [1, 2], axis=-3)
query = jnp.squeeze(query)
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(query,
key,
value,
bias=bias,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32)
return jnp.mean(output).astype(dtype)
with_bias = attn_bias_type != AttnBiasType.NO_BIAS
(qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = \
self.generate_inputs(data_shape, mesh_resource, with_bias,
attn_mask_type, dtype)
collective_count_ref = self.generate_collectives_count_ref(mesh_shape, mesh_axes,
mesh_resource, with_bias,
data_shape, dtype)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec))
bias_ = jax.device_put(bias, NamedSharding(mesh, bias_pspec)) \
if bias is not None else bias
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \
if mask is not None else mask
grad_args = (0, 1) if with_bias else (0,)
out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,)
compare_ops(target_func,
ref_func, [qkv_, bias_, mask_],
collective_count_ref,
grad_args=grad_args,
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
out_shardings=(None, out_grad_shardings))
class TestDistributedCrossAttn:
def generate_collectives_count_ref(self):
# for loss
all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype):
batch, seqlen, heads, hidden = shape
q = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
kv = random.normal(random.PRNGKey(1125), (batch, seqlen, 2, heads, hidden), dtype=dtype)
mask = None
if attn_mask_type == AttnMaskType.PADDING_MASK:
mask = make_causal_mask(batch, seqlen)
elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_self_mask(batch, seqlen)
q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None)
kv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource,
None)
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \
if attn_mask_type != AttnMaskType.NO_MASK else None
return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 12, 64], [32, 512, 16, 64]])
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dtype', DTYPES)
def test_cross_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
attn_mask_type, dtype):
attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0
is_training = True
scaling_factor = 1.0
_, seqlen, _, hidden = data_shape
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BSHD_BS2HD, attn_bias_type,
attn_mask_type, dropout_prob, seqlen, seqlen, hidden):
pytest.skip("No FusedAttn backend found")
def target_func(q, kv, mask):
return jnp.mean(
cross_fused_attn(q,
kv,
mask,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training))
def ref_func(query, kv, mask):
key, value = jnp.split(kv, [1], axis=-3)
query = jnp.squeeze(query)
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(query,
key,
value,
bias=None,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32)
return jnp.mean(output).astype(dtype)
(q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = \
self.generate_inputs(data_shape, mesh_resource, attn_mask_type, dtype)
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
q_ = jax.device_put(q, NamedSharding(mesh, q_pspec))
kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \
if mask is not None else mask
compare_ops(target_func,
ref_func, [q_, kv_, mask_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(q_pspec, kv_pspec, mask_pspec),
out_shardings=(None, (q_pspec, kv_pspec)))
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.layernorm import layernorm
DTYPES = [jnp.bfloat16, jnp.float32]
class TestDistributedLayernorm:
def generate_inputs(self, shape, mesh_resource, dtype):
weight_shape = (shape[-1],)
x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
gamma = jnp.ones(weight_shape, dtype=dtype)
beta = jnp.ones(weight_shape, dtype=dtype)
if len(shape) == 2:
x_pspec = PartitionSpec(mesh_resource.dp_resource, None)
elif len(shape) == 3:
x_pspec = PartitionSpec(mesh_resource.dp_resource, None, None)
else:
raise NotImplementedError
g_pspec = b_pspec = PartitionSpec(None)
return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)
def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
is_dp_enabled = mesh_resource.dp_resource is not None
assert ln_type in ['layernorm', 'rmsnorm']
all_reduce_loss_bytes = 4 # 1 * FP32
# for loss, dgamma and dbeta
weight_count = 2 if ln_type == 'layernorm' else 1
allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
return generate_collectives_count(allreduce=allreduce_total_bytes * int(is_dp_enabled),
allgather=0,
other=0)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('zero_centered_gamma', [False, True])
def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype,
zero_centered_gamma):
epsilon = 1e-6
ln_type = 'layernorm'
def target_func(x, gamma, beta):
return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon))
def ref_func(x, gamma, beta):
x_ = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x_ - mean) * jax.lax.rsqrt(var + epsilon)
if zero_centered_gamma:
output = jnp.asarray(normed_input * (gamma + 1) + beta).astype(x.dtype)
else:
output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
return jnp.mean(output)
(x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \
self.generate_inputs(data_shape, mesh_resource, dtype)
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
data_shape, dtype)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
compare_ops(target_func,
ref_func, [x_, gamma_, beta_],
collective_count_ref,
grad_args=(0, 1, 2),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)))
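As a standalone cross-check of `ref_func`, the same layernorm math (including the `zero_centered_gamma` variant, where the stored gamma is an offset from 1) can be written per-row in plain Python; the row and parameters below are illustrative:

```python
import math

def layernorm_row_ref(row, gamma, beta, zero_centered_gamma=False, epsilon=1e-6):
    """One-row layernorm mirroring the fp32 reference above."""
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)
    inv = 1.0 / math.sqrt(var + epsilon)
    scale = [g + 1 for g in gamma] if zero_centered_gamma else gamma
    return [(v - mean) * inv * s + b for v, s, b in zip(row, scale, beta)]

out = layernorm_row_ref([1.0, 2.0, 3.0, 4.0], gamma=[1.0] * 4, beta=[0.0] * 4)
# The normalized row has zero mean and unit variance (up to epsilon).
```

With `zero_centered_gamma=True`, a stored gamma of all zeros behaves like a unit gamma, which is what the flag in `layernorm` encodes.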
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES)
def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype):
epsilon = 1e-6
ln_type = 'rmsnorm'
def target_func(x, gamma):
return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon))
def ref_func(x, gamma):
x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * jax.lax.rsqrt(mean2 + epsilon), dtype)
output = y * gamma
return jnp.mean(output)
(x, gamma, _), (x_pspec, g_pspec, _) = \
self.generate_inputs(data_shape, mesh_resource, dtype)
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
data_shape, dtype)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
compare_ops(target_func,
ref_func, [x_, gamma_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec)))
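The RMSNorm reference above admits the same kind of standalone check: no mean subtraction, only division by the root mean square. A per-row plain-Python sketch with illustrative inputs:

```python
import math

def rmsnorm_row_ref(row, gamma, epsilon=1e-6):
    """One-row RMSNorm mirroring the fp32 reference above."""
    mean2 = sum(v * v for v in row) / len(row)
    inv = 1.0 / math.sqrt(mean2 + epsilon)
    return [v * inv * g for v, g in zip(row, gamma)]

out = rmsnorm_row_ref([3.0, 4.0], gamma=[1.0, 1.0])
# mean-square of the input is 12.5, so each value is scaled by ~1/sqrt(12.5);
# the output row then has mean-square ~1.
```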
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.softmax import SoftmaxType, softmax
DTYPES = [jnp.float16, jnp.bfloat16]
class TestDistributedSoftmax:
def generate_collectives_count_ref(self):
# for loss
all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(self, shape, mesh_resource, softmax_type, dtype):
batch, _, seqlen, _ = shape
x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
mask = make_causal_mask(batch, seqlen)
else:
mask = make_self_mask(batch, seqlen)
x_pspec = PartitionSpec(mesh_resource.dp_resource, mesh_resource.tp_resource, None, None)
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)
return (x, mask), (x_pspec, mask_pspec)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize(
'softmax_type',
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED])
@pytest.mark.parametrize('scale_factor', [1.0, 3.0])
@pytest.mark.parametrize('dtype', DTYPES)
def test_softmax(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
softmax_type, scale_factor, dtype):
def target_func(x, mask):
return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
def ref_func(x, mask):
bias = None
if mask is not None:
bias = jax.lax.select(mask > 0,
jnp.full(mask.shape, -1e10).astype(dtype),
jnp.full(mask.shape, 0.).astype(dtype))
if bias is not None:
x = x + bias.astype(dtype)
output = jax.nn.softmax(x * scale_factor)
return jnp.mean(output)
(x, mask), (x_pspec, mask_pspec) = \
self.generate_inputs(data_shape, mesh_resource, softmax_type, dtype)
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec))
compare_ops(target_func,
ref_func, [x_, mask_],
collective_count_ref,
grad_args=(0,),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, mask_pspec),
out_shardings=(None, (x_pspec,)))
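The mask-to-bias trick in `ref_func` above (positions with `mask > 0` receive a large negative bias so they vanish after softmax) can be checked standalone; the toy logits and mask below are illustrative:

```python
import math

def masked_softmax_ref(logits, mask, scale_factor=1.0):
    """mask > 0 marks positions to drop: they get a -1e10 bias, as in ref_func."""
    biased = [(x - 1e10 if m > 0 else x) * scale_factor
              for x, m in zip(logits, mask)]
    peak = max(biased)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in biased]
    total = sum(exps)
    return [e / total for e in exps]

probs = masked_softmax_ref([0.0, 0.0, 0.0, 0.0], mask=[0, 0, 1, 1])
# Mass splits evenly over the two unmasked positions; masked ones get ~0.
```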
@@ -14,9 +14,7 @@ from transformer_engine.common.recipe import DelayedScaling
 from transformer_engine.common.recipe import Format as FP8Format
 from transformer_engine.jax import fp8_autocast, get_delayed_scaling
 from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available, AmaxComputeAlgo
-from transformer_engine.jax.sharding import infer_major_sharding_type
-from transformer_engine.jax.sharding import MajorShardingType
-from transformer_engine.jax.sharding import ShardingResource
+from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
 is_fp8_supported, reason = is_fp8_available()
@@ -160,7 +158,6 @@ class TestFP8Functions(unittest.TestCase):
 def _check_defult_state(self):
 self.assertFalse(FP8Helper.is_fp8_enabled())
-self.assertEqual(infer_major_sharding_type(), MajorShardingType.SINGLE)
 def _compare_delay_scaling(self, ref, test):
 self.assertTrue(ref.margin == test.margin)
@@ -201,27 +198,20 @@
 ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)
-# TODO (Ming Huang): Suport multi-GPUs testing. # pylint: disable=fixme
-# srs = (
-# (ShardingResource(None, None), MajorShardingType.SINGLE),
-# (ShardingResource('dp', None), MajorShardingType.DP),
-# (ShardingResource(None, 'tp'), MajorShardingType.TP),
-# (ShardingResource('dp', 'tp'), MajorShardingType.DPTP),
-# )
-srs = (
-(ShardingResource(None, None), MajorShardingType.SINGLE),
-(ShardingResource('dp', None), MajorShardingType.SINGLE),
-(ShardingResource(None, 'tp'), MajorShardingType.SINGLE),
-(ShardingResource('dp', 'tp'), MajorShardingType.SINGLE),
-)
+mesh_s = (
+(MeshResource(None, None)),
+(MeshResource('dp', None)),
+(MeshResource(None, 'tp')),
+(MeshResource('dp', 'tp')),
+)
 # TODO (Ming Huang): Suport multi-GPUs testing. # pylint: disable=fixme
 mesh_shape = (1, 1)
 devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
 with jax.sharding.Mesh(devices, ('dp', 'tp')):
-for sr, mst in srs:
-with fp8_autocast(enabled=True, fp8_recipe=ds, sharding_resource=sr):
+for sr in mesh_s:
+with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
 self.assertTrue(FP8Helper.is_fp8_enabled())
 self._compare_delay_scaling(get_delayed_scaling(), ds)
-self.assertEqual(infer_major_sharding_type(), mst)
+self.assertEqual(sr, global_mesh_resource())
 self._check_defult_state()
@@ -2,40 +2,10 @@
 #
 # See LICENSE for license information.
-import jax
-import numpy as np
 import pytest
-from utils import is_devices_enough
 from transformer_engine.jax.flax import extend_logical_axis_rules
-from transformer_engine.jax.sharding import get_dot_sharding_meta
-from transformer_engine.jax.sharding import get_elementwise_sharding_meta
-from transformer_engine.jax.sharding import get_fp8_meta_sharding_meta
-from transformer_engine.jax.sharding import global_shard_guard
-from transformer_engine.jax.sharding import infer_major_sharding_type
-from transformer_engine.jax.sharding import is_dp_enabled, is_tp_enabled
-from transformer_engine.jax.sharding import ShardingMeta, ShardingResource, ShardingType
+from transformer_engine.jax.sharding import global_shard_guard, MeshResource
-def _get_sharding_resource(mesh_names, sharding_type):
-dp_r = None
-tp_r = None
-if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
-dp_r = mesh_names[0]
-if sharding_type in (ShardingType.TP_COL, ShardingType.TP_ROW):
-tp_r = mesh_names[0]
-if sharding_type in (ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
-tp_r = mesh_names[1]
-return ShardingResource(dp_r, tp_r)
-DEVICE_COUNT = 4
-MESH_CONFIG = [((4,), ("dp",), ShardingType.DP), ((4,), ("tp",), ShardingType.TP_COL),
-((4,), ("tp",), ShardingType.TP_ROW), ((2, 2), ("dp", "tp"), ShardingType.DP_TP_COL),
-((2, 2), ("dp", "tp"), ShardingType.DP_TP_ROW)]
 LOGICAL_RULES = [
 [(('a1', None), ('a2', 'ma2')), False],
@@ -44,18 +14,19 @@ LOGICAL_RULES = [
 [(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True],
 [(('a1', None), ('a2', 'ma2'), ('a2', 'ma1'), ('batch', 'model'), ('batch', 'data')), True],
 ]
-SRS = [
-ShardingResource(),
-ShardingResource('data', None),
-ShardingResource(None, 'model'),
-ShardingResource('data', 'model')
-]
+MeshS = [
+MeshResource(),
+MeshResource('data', None),
+MeshResource(None, 'model'),
+MeshResource('data', 'model')
+]
 class TestShardingSideAPI:
 @pytest.mark.parametrize('base_rules,need_assert', LOGICAL_RULES)
-@pytest.mark.parametrize('sr', SRS)
+@pytest.mark.parametrize('sr', MeshS)
 def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
 with global_shard_guard(sr):
 try:
@@ -65,270 +36,3 @@ class TestShardingSideAPI:
 assert not need_assert
 except AssertionError as ae:
 assert need_assert, f"{ae.args}"
class TestGeneralFunc:
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
@pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
def test_infer_major_sharding_type(
self,
mesh_shape, # pylint: disable=unused-argument
mesh_names,
sharding_type):
devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
with jax.sharding.Mesh(devices, mesh_names):
assert infer_major_sharding_type() is sharding_type.value[0]
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
def test_is_dp_enabled(
self,
mesh_shape, # pylint: disable=unused-argument
mesh_names, # pylint: disable=unused-argument
sharding_type):
if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
assert is_dp_enabled(sharding_type.value[0])
else:
assert not is_dp_enabled(sharding_type.value[0])
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
def test_is_tp_enabled(
self,
mesh_shape, # pylint: disable=unused-argument
mesh_names, # pylint: disable=unused-argument
sharding_type):
if sharding_type is ShardingType.DP:
assert not is_tp_enabled(sharding_type.value[0])
else:
assert is_tp_enabled(sharding_type.value[0])
class TestShardingMetaGenerator:
BATCH_AXIS_NAME = 'batch'
MODEL_AXIS_NAME = 'model'
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
@pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
def test_fp8_meta(self, mesh_shape, mesh_names, sharding_type, num_of_fp8_meta=4):
def stack_axes_meta(mapping):
return tuple(mapping for _ in range(num_of_fp8_meta))
def get_ref_sm():
if sharding_type == ShardingType.DP:
return ShardingMeta(stack_axes_meta({}), stack_axes_meta({}),
{TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0]}, (),
())
if sharding_type == ShardingType.TP_COL:
return ShardingMeta(stack_axes_meta({}), stack_axes_meta({}),
{TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]}, (),
())
if sharding_type == ShardingType.TP_ROW:
return ShardingMeta(stack_axes_meta({}), stack_axes_meta({}),
{TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]}, (),
())
if sharding_type == ShardingType.DP_TP_COL:
return ShardingMeta(
stack_axes_meta({}), stack_axes_meta({}), {
TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
}, (), ())
if sharding_type == ShardingType.DP_TP_ROW:
return ShardingMeta(
stack_axes_meta({}), stack_axes_meta({}), {
TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
}, (), ())
return None
devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
with jax.sharding.Mesh(devices, mesh_names):
test_sm = get_fp8_meta_sharding_meta(
sharding_type,
num_of_fp8_meta,
dp_axis_name=TestShardingMetaGenerator.BATCH_AXIS_NAME,
tp_axis_name=TestShardingMetaGenerator.MODEL_AXIS_NAME)
assert test_sm == get_ref_sm()
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
@pytest.mark.parametrize('a_shape, b_shape', [((64, 128, 256), (256, 512)),
((128, 64, 512), (512, 256))])
@pytest.mark.parametrize('batch_dim_of_a', [0, 1])
@pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
def test_dot(self, mesh_shape, mesh_names, sharding_type, a_shape, b_shape, batch_dim_of_a):
model_dim_of_a = len(a_shape) - 1
model_dim_of_b = 0 if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW) else 1
contracting_dims = ((-1,), (0,))
def get_ref_sm():
out_shape = (*a_shape[:min(contracting_dims[0])],
*b_shape[max(contracting_dims[1]) + 1:])
if sharding_type == ShardingType.DP:
a_new_shape = (*a_shape[:batch_dim_of_a], mesh_shape[0], -1,
*a_shape[batch_dim_of_a + 1:])
return ShardingMeta(({
batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME
}, {}), ({
batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME
}), {TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0]},
[a_new_shape, b_shape], [out_shape])
if sharding_type == ShardingType.TP_COL:
b_new_shape = (b_shape[0], mesh_shape[0], b_shape[1] // mesh_shape[0])
return ShardingMeta(({}, {
1: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), ({
len(out_shape) - 1: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]},
[a_shape, b_new_shape], [out_shape])
if sharding_type == ShardingType.TP_ROW:
a_new_shape = (*a_shape[:-1], mesh_shape[0], a_shape[-1] // mesh_shape[0])
b_new_shape = (mesh_shape[0], b_shape[0] // mesh_shape[0], b_shape[1])
return ShardingMeta(({
len(a_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
}, {
0: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), ({}), {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]},
[a_new_shape, b_new_shape], [out_shape])
if sharding_type == ShardingType.DP_TP_COL:
a_new_shape = (*a_shape[:batch_dim_of_a], mesh_shape[0],
a_shape[batch_dim_of_a] // mesh_shape[0],
*a_shape[batch_dim_of_a + 1:])
b_new_shape = (b_shape[0], mesh_shape[1], b_shape[1] // mesh_shape[1])
return ShardingMeta(
({
batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME
}, {
1: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), ({
batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME,
len(out_shape): TestShardingMetaGenerator.MODEL_AXIS_NAME
}), {
TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
}, [a_new_shape, b_new_shape], [out_shape])
if sharding_type == ShardingType.DP_TP_ROW:
a_new_shape = (*a_shape[:batch_dim_of_a], mesh_shape[0],
a_shape[batch_dim_of_a] // mesh_shape[0],
*a_shape[batch_dim_of_a + 1:-1], mesh_shape[1],
a_shape[-1] // mesh_shape[1])
b_new_shape = (mesh_shape[1], b_shape[0] // mesh_shape[1], b_shape[1])
return ShardingMeta(
({
batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME,
len(a_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
}, {
0: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), ({
batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME
}), {
TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
}, [a_new_shape, b_new_shape], [out_shape])
return None
devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
with jax.sharding.Mesh(devices, mesh_names):
test_sm = get_dot_sharding_meta(
sharding_type,
a_shape,
b_shape,
batch_dim_of_a,
model_dim_of_a,
model_dim_of_b,
contracting_dims,
dp_axis_name=TestShardingMetaGenerator.BATCH_AXIS_NAME,
tp_axis_name=TestShardingMetaGenerator.MODEL_AXIS_NAME)
assert test_sm == get_ref_sm()
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
@pytest.mark.parametrize('input_shape', [(64, 128, 256), (128, 64, 512)])
@pytest.mark.parametrize('other_shape', [(256,), (512,)])
@pytest.mark.parametrize('batch_dim', [0, 1])
@pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
def test_elementwise(self, mesh_shape, mesh_names, sharding_type, input_shape, other_shape,
batch_dim):
def get_ref_sm():
need_assert = True
ref_sharding_meta = None
if input_shape[-1] != other_shape[0]:
need_assert = True
ref_sharding_meta = None
elif sharding_type is (ShardingType.DP_TP_COL, ShardingType.DP):
need_assert = False
input_new_shape = (*input_shape[:batch_dim], mesh_shape[0], -1,
*input_shape[batch_dim + 1:])
ref_sharding_meta = ShardingMeta(({
batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME
}, {}), ({
batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME
}), {TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0]},
[input_new_shape, other_shape], [input_shape])
elif sharding_type is ShardingType.TP_COL:
need_assert = False
ref_sharding_meta = ShardingMeta(({}, {}), ({}), {}, [input_shape, other_shape],
[input_shape])
elif sharding_type is ShardingType.TP_ROW:
need_assert = False
input_new_shape = (*input_shape[:-1], mesh_shape[0], -1)
other_new_shape = (mesh_shape[0], -1)
ref_sharding_meta = ShardingMeta(({
len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
}, {
0: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), ({
len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]},
[input_new_shape, other_new_shape], [input_shape])
elif sharding_type is ShardingType.DP_TP_ROW:
need_assert = False
input_new_shape = (*input_shape[:batch_dim], mesh_shape[0], -1,
*input_shape[batch_dim + 1:-1], mesh_shape[1],
input_shape[-1] // mesh_shape[1])
other_new_shape = (mesh_shape[0], -1)
ref_sharding_meta = ShardingMeta(
({
batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME,
len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
}, {
0: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), ({
batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME,
len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), {
TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
}, [input_new_shape, other_new_shape], [input_shape])
return ref_sharding_meta, need_assert
devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
with jax.sharding.Mesh(devices, mesh_names):
ref_sm, need_assert = get_ref_sm()
try:
test_sm = get_elementwise_sharding_meta(
sharding_type,
input_shape,
other_shape,
batch_dim,
dp_axis_name=TestShardingMetaGenerator.BATCH_AXIS_NAME,
tp_axis_name=TestShardingMetaGenerator.MODEL_AXIS_NAME)
assert not need_assert
assert test_sm == ref_sm
except (NotImplementedError, AssertionError) as e:
assert need_assert, f"{e.args}"
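A standalone check of the reshape arithmetic the `TP_ROW` branch above relies on, using hypothetical shapes and plain NumPy (the test itself runs under a JAX mesh):

```python
import numpy as np

# Hypothetical shapes mirroring the TP_ROW branch above.
input_shape = (4, 8, 16)   # (batch, seq, hidden)
mesh_shape = (2,)          # one model-parallel axis of size 2

# TP_ROW splits the last (hidden) dim across the model axis;
# -1 lets reshape infer the per-shard hidden size.
input_new_shape = (*input_shape[:-1], mesh_shape[0], -1)

x = np.zeros(input_shape)
x_sharded = x.reshape(input_new_shape)
print(x_sharded.shape)  # (4, 8, 2, 8): hidden 16 split into 2 shards of 8
```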
@@ -26,6 +26,7 @@ PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Preci
                                           lax.Precision]]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def is_devices_enough(required):
    return len(jax.devices()) >= required


@@ -1010,6 +1011,24 @@ class DecoderLayer(nn.Module):
        return z
def make_causal_mask(batch, seqlen, dtype=jnp.uint8):
    """Create a causal attention mask; 1 marks masked-out (future) positions."""
    shape = (batch, seqlen)
    idxs = jnp.broadcast_to(jnp.arange(shape[-1], dtype=jnp.int32), shape)
    mask = jnp.greater_equal(jnp.expand_dims(idxs, axis=-1), jnp.expand_dims(idxs, axis=-2))
    mask = jnp.expand_dims(mask, axis=-3)    # add head dim: (batch, 1, seqlen, seqlen)
    mask = 1 - mask                          # invert: 1 = masked out, 0 = attendable
    return mask.astype(dtype)


def make_self_mask(batch, seqlen, dtype=jnp.uint8):
    """Create an all-zeros (nothing masked) mask of shape (batch, 1, seqlen, seqlen)."""
    shape = (batch, seqlen)
    mask = jnp.ones((*shape, shape[-1]))
    mask = jnp.expand_dims(mask, axis=-3)    # add head dim
    mask = 1 - mask                          # all zeros: no position is masked
    return mask.astype(dtype)
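A quick sanity check of the causal-mask convention above (1 = masked out), mirrored with plain NumPy so it runs standalone; the original uses `jax.numpy`, whose semantics match here:

```python
import numpy as np

# NumPy mirror of make_causal_mask above (illustrative standalone check).
def make_causal_mask_np(batch, seqlen, dtype=np.uint8):
    idxs = np.broadcast_to(np.arange(seqlen, dtype=np.int32), (batch, seqlen))
    # True where query index >= key index (attendable), then inverted.
    mask = np.greater_equal(np.expand_dims(idxs, -1), np.expand_dims(idxs, -2))
    mask = np.expand_dims(mask, -3)          # (batch, 1, seqlen, seqlen)
    return (1 - mask).astype(dtype)

m = make_causal_mask_np(1, 3)
print(m[0, 0])
# [[0 1 1]
#  [0 0 1]
#  [0 0 0]]
```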
def assert_allclose(
        actual: Array,
        desired: Array,
@@ -1092,7 +1111,7 @@ def dtype_tols(
    # Estimate floating-point error
    finfo = jnp.finfo(dtype)
    eps_relaxed = math.pow(finfo.eps, 2 / 3)
    with jax.default_device(jax.devices("cpu")[0]):
        if isinstance(reference_value, (float, int)):
            reference_value = jnp.array(reference_value, dtype=dtype)
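The `2 / 3` exponent relaxes machine epsilon to a tolerance between `eps` itself and `sqrt(eps)` (since `eps < 1`, a smaller exponent yields a larger value). A quick standalone check for float32, using NumPy in place of `jnp.finfo`:

```python
import math
import numpy as np

finfo = np.finfo(np.float32)
eps_relaxed = math.pow(finfo.eps, 2 / 3)

# eps ~1.19e-07  <  eps_relaxed ~2.4e-05  <  sqrt(eps) ~3.5e-04
assert finfo.eps < eps_relaxed < math.sqrt(finfo.eps)
print(f"{eps_relaxed:.1e}")  # ~2.4e-05
```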
@@ -5,10 +5,30 @@
from . import flax
from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling
from .fp8 import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType from .sharding import MajorShardingType, ShardingResource, ShardingType
from ..common.utils import deprecate_wrapper
from ..common.utils import DeprecatedEnum
MajorShardingType = DeprecatedEnum(MajorShardingType,
                                   "MajorShardingType will be deprecated in the near future.")
ShardingType = DeprecatedEnum(ShardingType, "ShardingType will be deprecated in the near future.")
ShardingResource = deprecate_wrapper(
    ShardingResource,
    "ShardingResource is renamed to MeshResource and will be removed in the near future.")
__all__ = [
    'NVTE_FP8_COLLECTION_NAME',
    'fp8_autocast',
    'update_collections',
    'update_fp8_metas',
    'get_delayed_scaling',
    'MeshResource',
    'MajorShardingType',
    'ShardingResource',
    'ShardingType',
    'flax',
    'praxis',
]