Unverified commit 30cad990, authored by Alp Dener, committed by GitHub

[JAX] Bugfix for insufficient GPUs crash in distributed ops (#505)



* Fixed a minor bug where DistributedConfigsHelper crashed the test run for insufficient GPUs before the @pytest.mark.skipif condition could take effect (a minimal sketch of the failure mode follows the commit metadata below).
Signed-off-by: Alp Dener <adener@nvidia.com>

* Update tests/jax/distributed_configs_helper.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Debug PyTest errors when running on a single-GPU system
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 74eb7c33
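Editor's note: the bug described in the first change note is generic pytest behavior. The config helper is instantiated at module scope in the test file, so an exception raised in its constructor fires while pytest imports the module, i.e. during collection, before any @pytest.mark.skipif marker is evaluated. A minimal, self-contained sketch of that failure mode (file and class names here are illustrative, not the actual TE test files):

# repro_skip_vs_collection_error.py -- illustrative only, not part of the TE test suite.
import jax
import pytest


class ConfigsThatRaise:
    """Mimics the old ShardingConfigs: refuses to construct on fewer than 2 GPUs."""

    def __init__(self, num_gpus=len(jax.devices())):
        if num_gpus < 2:
            raise ValueError("Need at least 2 GPUs.")
        self.device_count = min(num_gpus, 8)


# Module-level instantiation: on a single-GPU machine this raises while pytest
# imports the module, producing a collection ERROR even though the test class
# below carries a skipif marker.
configs = ConfigsThatRaise()


@pytest.mark.skipif(len(jax.devices()) < 2, reason="Need at least 2 GPUs.")
class TestSomethingDistributed:

    def test_noop(self):
        assert configs.device_count >= 2

The fix below keeps construction side-effect free: the helper records an empty configuration instead of raising, so importing the test module always succeeds and the skip markers get a chance to run.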
@@ -8,14 +8,22 @@ from transformer_engine.jax.softmax import SoftmaxType
 from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType
-class ShardingConfigs(object):
+class DistributedConfigsHelper(object):
     def __init__(self, num_gpus=len(jax.devices())):
         super().__init__()
-        if num_gpus < 2:
-            raise ValueError(f"ShardingConfig: Need at least 2 GPUs.")
         self.device_count = min(num_gpus, 8)
+        if self.device_count < 2:
+            self.layernorm_refs = []
+            self.softmax_types = []
+            self.softmax_refs = []
+            self.self_attn_bias_types = []
+            self.self_attn_mask_types = []
+            self.self_attn_refs = []
+            self.cross_attn_mask_types = []
+            self.cross_attn_refs = []
+            return
         mesh_configs = [
             ((self.device_count, 1), ("dp", None), ShardingType.DP),
             ((self.device_count, 1), ("tp", None), ShardingType.TP_COL),
......
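Editor's note: with the change above, constructing DistributedConfigsHelper on a single-GPU system succeeds, sets device_count, and leaves every configuration list empty, so the parametrized tests that consume those lists have nothing to run. A rough sketch of the resulting behavior (the parametrize signature is abbreviated to a single argument; the real tests take a 4-tuple):

# Sketch only: assumes tests/jax is on sys.path so the helper module imports,
# as it does when running the TE JAX test suite.
import pytest
from distributed_configs_helper import DistributedConfigsHelper

configs = DistributedConfigsHelper(num_gpus=1)   # forces the early-return path
assert configs.device_count == 1
assert configs.layernorm_refs == []              # every config list stays empty

# With an empty argvalues list, pytest has no parameter sets to run (by default
# it reports the test as skipped), so no case ever tries to build a multi-GPU mesh.
@pytest.mark.parametrize('collective_ref', configs.layernorm_refs)
def test_layernorm_sketch(collective_ref):
    ...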
@@ -44,7 +44,7 @@ def fixture_backend(request):
 @dataclass
-class CustomOpsTestHelper:
+class DistributedOpsHelper:
     qkv_shape: Tuple[int,int,int,int] = (32, 128, 16, 64)
     pad_ratio: float = 0.3
     dropout_prob: float = 0.1
......
@@ -11,18 +11,19 @@ from jax import random
 from jax.sharding import NamedSharding
 from utils import is_devices_enough
-from sharding_configs import *
-from custom_ops_helper import *
+from distributed_configs_helper import *
+from distributed_ops_helper import *
 from transformer_engine.jax.sharding import global_shard_guard
 from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType
-configs = ShardingConfigs() # default device count is len(jax.devices())
-helper = CustomOpsTestHelper()
+configs = DistributedConfigsHelper() # default device count is len(jax.devices())
+ops = DistributedOpsHelper() # default data type is jnp.float16
-@pytest.mark.skipif(not helper.use_custom_partitioning(),
+@pytest.mark.skipif(not is_devices_enough(configs.device_count),
+                    reason='Insufficient number of GPUs, need at least 2.')
+@pytest.mark.skipif(not ops.use_custom_partitioning(),
                     reason='TE/JAX version does not support sharding with ' + \
                            'jax.experimental.custom_partitioning.')
-@pytest.mark.skipif(not is_devices_enough(configs.device_count), reason='Num of GPU is not enough')
 class TestCustomPartitioningOpsGenerator:
     @pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref',
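Editor's note: the device-count guard now comes first with a clearer reason string and relies on is_devices_enough from the utils module imported above (not shown in this diff). A minimal sketch of what such a guard presumably looks like, under the assumption that it only compares the requested count against the visible JAX devices:

import jax

def is_devices_enough(required):
    # Hypothetical stand-in for the helper imported from utils above; the real
    # implementation may differ, but the skipif only needs a boolean like this.
    return len(jax.devices()) >= required

Because DistributedConfigsHelper no longer raises, configs.device_count is always defined here (it is 1 on a single-GPU machine), so this expression can be evaluated safely while pytest collects the module.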
@@ -32,33 +33,33 @@ class TestCustomPartitioningOpsGenerator:
                        zero_centered_gamma):
         epsilon = 1e-6
-        custom_func = partial(helper.custom_layernorm,
+        custom_func = partial(ops.custom_layernorm,
                               zero_centered_gamma=zero_centered_gamma,
                               epsilon=epsilon,
                               sharding_type=sharding_type)
-        reference_func = partial(helper.reference_layernorm,
+        reference_func = partial(ops.reference_layernorm,
                                  zero_centered_gamma=zero_centered_gamma,
                                  epsilon=epsilon)
-        batch_size, _, num_heads, head_dim = helper.qkv_shape
+        batch_size, _, num_heads, head_dim = ops.qkv_shape
         hidden_size = num_heads*head_dim
         input_shape = (batch_size, hidden_size)
         other_shape = (hidden_size, )
-        x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=helper.dtype)
-        gamma_ = jnp.ones(other_shape, dtype=helper.dtype)
-        beta_ = jnp.ones(other_shape, dtype=helper.dtype)
+        x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=ops.dtype)
+        gamma_ = jnp.ones(other_shape, dtype=ops.dtype)
+        beta_ = jnp.ones(other_shape, dtype=ops.dtype)
-        x_spec, gamma_spec, beta_spec = helper.get_sharding_spec(mesh_names, sharding_type)
+        x_spec, gamma_spec, beta_spec = ops.get_sharding_spec(mesh_names, sharding_type)
         devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
         mesh = jax.sharding.Mesh(devices, mesh_names)
-        with mesh, global_shard_guard(helper.get_sharding_resource(mesh_names, sharding_type)):
+        with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
             x_ = jax.device_put(x_, NamedSharding(mesh, x_spec))
             gamma_ = jax.device_put(gamma_, NamedSharding(mesh, gamma_spec))
             beta_ = jax.device_put(beta_, NamedSharding(mesh, beta_spec))
-            helper.compare_ops(
+            ops.compare_ops(
                 custom_func, reference_func, collective_ref,
-                x_, gamma_, beta_, grad_args=(0, 1, 2), dtype=helper.dtype,
+                x_, gamma_, beta_, grad_args=(0, 1, 2), dtype=ops.dtype,
                 in_shardings=[x_spec, gamma_spec, beta_spec],
                 out_shardings=(None, (x_spec, gamma_spec, beta_spec))
             )
@@ -67,25 +68,25 @@ class TestCustomPartitioningOpsGenerator:
                              configs.layernorm_refs)
     def test_rmsnorm(self, mesh_shape, mesh_names, sharding_type, collective_ref):
         epsilon = 1e-6
-        custom_func = partial(helper.custom_rmsnorm, epsilon=epsilon,sharding_type=sharding_type)
-        reference_func = partial(helper.reference_rmsnorm, epsilon=epsilon)
+        custom_func = partial(ops.custom_rmsnorm, epsilon=epsilon,sharding_type=sharding_type)
+        reference_func = partial(ops.reference_rmsnorm, epsilon=epsilon)
-        batch_size, _, num_heads, head_dim = helper.qkv_shape
+        batch_size, _, num_heads, head_dim = ops.qkv_shape
         hidden_size = num_heads*head_dim
         input_shape = (batch_size, hidden_size)
         other_shape = (hidden_size, )
-        x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=helper.dtype)
-        gamma_ = jnp.ones(other_shape, dtype=helper.dtype)
+        x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=ops.dtype)
+        gamma_ = jnp.ones(other_shape, dtype=ops.dtype)
-        x_spec, gamma_spec = helper.get_sharding_spec(mesh_names, sharding_type)
+        x_spec, gamma_spec = ops.get_sharding_spec(mesh_names, sharding_type)
         devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
         mesh = jax.sharding.Mesh(devices, mesh_names)
-        with mesh, global_shard_guard(helper.get_sharding_resource(mesh_names, sharding_type)):
+        with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
             x_ = jax.device_put(x_, NamedSharding(mesh, x_spec))
             gamma_ = jax.device_put(gamma_, NamedSharding(mesh, gamma_spec))
-            helper.compare_ops(
+            ops.compare_ops(
                 custom_func, reference_func, collective_ref,
-                x_, gamma_, grad_args=(0, 1), dtype=helper.dtype,
+                x_, gamma_, grad_args=(0, 1), dtype=ops.dtype,
                 in_shardings=[x_spec, gamma_spec],
                 out_shardings=(None, (x_spec, gamma_spec))
             )
@@ -95,36 +96,36 @@ class TestCustomPartitioningOpsGenerator:
     @pytest.mark.parametrize('softmax_type', configs.softmax_types)
     def test_softmax(self, mesh_shape, mesh_names, sharding_type, collective_ref,
                      softmax_type):
-        batch_size, seq_len, num_heads, head_dim = helper.qkv_shape
+        batch_size, seq_len, num_heads, head_dim = ops.qkv_shape
         scale_factor = 1./jnp.sqrt(head_dim)
-        custom_func = partial(helper.custom_softmax,
+        custom_func = partial(ops.custom_softmax,
                               scale_factor=scale_factor,
                               softmax_type=softmax_type,
                               sharding_type=sharding_type)
-        reference_func = partial(helper.reference_softmax,
+        reference_func = partial(ops.reference_softmax,
                                  scale_factor=scale_factor,
                                  softmax_type=softmax_type)
         input_size = (batch_size, num_heads, seq_len, seq_len)
-        x_ = random.normal(random.PRNGKey(1124), input_size, dtype=helper.dtype)
+        x_ = random.normal(random.PRNGKey(1124), input_size, dtype=ops.dtype)
-        pad_len = int(seq_len * helper.pad_ratio)
+        pad_len = int(seq_len * ops.pad_ratio)
         valid_len = seq_len - pad_len
-        tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=helper.dtype),
-                                  jnp.zeros((batch_size, pad_len), dtype=helper.dtype)),
+        tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=ops.dtype),
+                                  jnp.zeros((batch_size, pad_len), dtype=ops.dtype)),
                                  axis=-1)
-        mask_ = helper.make_mask(tokens, tokens, AttnMaskType.PADDING_MASK)
+        mask_ = ops.make_mask(tokens, tokens, AttnMaskType.PADDING_MASK)
-        x_spec, mask_spec = helper.get_sharding_spec(mesh_names, sharding_type)
+        x_spec, mask_spec = ops.get_sharding_spec(mesh_names, sharding_type)
         devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
         mesh = jax.sharding.Mesh(devices, mesh_names)
-        with mesh, global_shard_guard(helper.get_sharding_resource(mesh_names, sharding_type)):
+        with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
             x_ = jax.device_put(x_, NamedSharding(mesh, x_spec))
             mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec))
-            helper.compare_ops(
+            ops.compare_ops(
                 custom_func, reference_func, collective_ref,
-                (0), x_, mask_, grad_args=(0), dtype=helper.dtype,
+                (0), x_, mask_, grad_args=(0), dtype=ops.dtype,
                 in_shardings=[x_spec, mask_spec],
                 out_shardings=(None, (x_spec))
             )
@@ -135,25 +136,25 @@ class TestCustomPartitioningOpsGenerator:
     @pytest.mark.parametrize('attn_mask_type', configs.self_attn_mask_types)
     def test_self_fused_attn(self, mesh_shape, mesh_names, sharding_type, collective_ref,
                              attn_bias_type, attn_mask_type, backend):
-        batch_size, seq_len, num_heads, head_dim = helper.qkv_shape
-        helper.check_fused_attn_inputs(seq_len, seq_len, head_dim,
-                                       helper.pad_ratio, helper.dropout_prob,
+        batch_size, seq_len, num_heads, head_dim = ops.qkv_shape
+        ops.check_fused_attn_inputs(seq_len, seq_len, head_dim,
+                                    ops.pad_ratio, ops.dropout_prob,
                                     attn_bias_type, attn_mask_type, backend)
         dropout_rng = random.PRNGKey(91023051)
         split_rng = random.split(dropout_rng, configs.device_count)
         scale_factor = 1./jnp.sqrt(head_dim)
-        custom_func = partial(helper.custom_self_fused_attn,
+        custom_func = partial(ops.custom_self_fused_attn,
                               rng_key=split_rng,
-                              dropout_prob=helper.dropout_prob,
+                              dropout_prob=ops.dropout_prob,
                               attn_bias_type=attn_bias_type,
                               attn_mask_type=attn_mask_type,
                               scaling_factor=scale_factor,
                               sharding_type=sharding_type)
-        reference_func = partial(helper.reference_self_fused_attn,
+        reference_func = partial(ops.reference_self_fused_attn,
                                  rng_key=dropout_rng,
-                                 dropout_prob=helper.dropout_prob,
+                                 dropout_prob=ops.dropout_prob,
                                  attn_bias_type=attn_bias_type,
                                  attn_mask_type=attn_mask_type,
                                  scaling_factor=scale_factor)
@@ -162,27 +163,27 @@ class TestCustomPartitioningOpsGenerator:
         subkeys = random.split(key, 2)
         qkv_shape = (batch_size, seq_len, 3, num_heads, head_dim)
-        qkv_ = random.normal(subkeys[0], qkv_shape, dtype=helper.dtype)
+        qkv_ = random.normal(subkeys[0], qkv_shape, dtype=ops.dtype)
         bias_shape = (1, num_heads, seq_len, seq_len)
-        bias_ = random.normal(subkeys[1], bias_shape, dtype=helper.dtype)
+        bias_ = random.normal(subkeys[1], bias_shape, dtype=ops.dtype)
-        pad_len = int(seq_len * helper.pad_ratio)
+        pad_len = int(seq_len * ops.pad_ratio)
         valid_len = seq_len - pad_len
-        tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=helper.dtype),
-                                  jnp.zeros((batch_size, pad_len), dtype=helper.dtype)),
+        tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=ops.dtype),
+                                  jnp.zeros((batch_size, pad_len), dtype=ops.dtype)),
                                  axis=-1)
-        mask_ = helper.make_mask(tokens, tokens, attn_mask_type)
+        mask_ = ops.make_mask(tokens, tokens, attn_mask_type)
-        qkv_spec, bias_spec, mask_spec = helper.get_sharding_spec(mesh_names, sharding_type)
+        qkv_spec, bias_spec, mask_spec = ops.get_sharding_spec(mesh_names, sharding_type)
         devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
         mesh = jax.sharding.Mesh(devices, mesh_names)
-        with mesh, global_shard_guard(helper.get_sharding_resource(mesh_names, sharding_type)):
+        with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
             qkv_ = jax.device_put(qkv_, NamedSharding(mesh, qkv_spec))
             bias_ = jax.device_put(bias_, NamedSharding(mesh, bias_spec))
             mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec))
-            helper.compare_ops(
+            ops.compare_ops(
                 custom_func, reference_func, collective_ref,
-                qkv_, bias_, mask_, grad_args=(0, 1), dtype=helper.dtype,
+                qkv_, bias_, mask_, grad_args=(0, 1), dtype=ops.dtype,
                 in_shardings=[qkv_spec, bias_spec, mask_spec],
                 out_shardings=(None, (qkv_spec, bias_spec))
             )
@@ -192,24 +193,24 @@ class TestCustomPartitioningOpsGenerator:
     @pytest.mark.parametrize('attn_mask_type', configs.cross_attn_mask_types)
     def test_cross_fused_attn(self, mesh_shape, mesh_names, sharding_type, collective_ref,
                               attn_mask_type, backend):
-        batch_size, seq_len, num_heads, head_dim = helper.qkv_shape
-        helper.check_fused_attn_inputs(seq_len, seq_len, head_dim,
-                                       helper.pad_ratio, helper.dropout_prob,
+        batch_size, seq_len, num_heads, head_dim = ops.qkv_shape
+        ops.check_fused_attn_inputs(seq_len, seq_len, head_dim,
+                                    ops.pad_ratio, ops.dropout_prob,
                                     AttnBiasType.NO_BIAS, attn_mask_type, backend)
         dropout_rng = random.PRNGKey(91023051)
         split_rng = random.split(dropout_rng, configs.device_count)
         scale_factor = 1./jnp.sqrt(head_dim)
-        custom_func = partial(helper.custom_cross_fused_attn,
+        custom_func = partial(ops.custom_cross_fused_attn,
                               rng_key=split_rng,
-                              dropout_prob=helper.dropout_prob,
+                              dropout_prob=ops.dropout_prob,
                               attn_mask_type=attn_mask_type,
                               scaling_factor=scale_factor,
                               sharding_type=sharding_type)
-        reference_func = partial(helper.reference_cross_fused_attn,
+        reference_func = partial(ops.reference_cross_fused_attn,
                                  rng_key=split_rng,
-                                 dropout_prob=helper.dropout_prob,
+                                 dropout_prob=ops.dropout_prob,
                                  attn_mask_type=attn_mask_type,
                                  scaling_factor=scale_factor)
@@ -217,27 +218,27 @@ class TestCustomPartitioningOpsGenerator:
         subkeys = random.split(key, 2)
         q_shape = (batch_size, seq_len, num_heads, head_dim)
-        q_ = random.normal(subkeys[0], q_shape, dtype=helper.dtype)
+        q_ = random.normal(subkeys[0], q_shape, dtype=ops.dtype)
         kv_shape = (batch_size, seq_len, 2, num_heads, head_dim)
-        kv_ = random.normal(subkeys[1], kv_shape, dtype=helper.dtype)
+        kv_ = random.normal(subkeys[1], kv_shape, dtype=ops.dtype)
-        pad_len = int(seq_len * helper.pad_ratio)
+        pad_len = int(seq_len * ops.pad_ratio)
         valid_len = seq_len - pad_len
-        tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=helper.dtype),
-                                  jnp.zeros((batch_size, pad_len), dtype=helper.dtype)),
+        tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=ops.dtype),
+                                  jnp.zeros((batch_size, pad_len), dtype=ops.dtype)),
                                  axis=-1)
-        mask_ = helper.make_mask(tokens, tokens, attn_mask_type)
+        mask_ = ops.make_mask(tokens, tokens, attn_mask_type)
-        q_spec, kv_spec, mask_spec = helper.get_sharding_spec(mesh_names, sharding_type)
+        q_spec, kv_spec, mask_spec = ops.get_sharding_spec(mesh_names, sharding_type)
         devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
         mesh = jax.sharding.Mesh(devices, mesh_names)
-        with mesh, global_shard_guard(helper.get_sharding_resource(mesh_names, sharding_type)):
+        with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
             q_ = jax.device_put(q_, NamedSharding(mesh, q_spec))
             kv_ = jax.device_put(kv_, NamedSharding(mesh, kv_spec))
             mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec))
-            helper.compare_ops(
+            ops.compare_ops(
                 custom_func, reference_func, collective_ref,
-                q_, kv_, mask_, grad_args=(0, 1), dtype=helper.dtype,
+                q_, kv_, mask_, grad_args=(0, 1), dtype=ops.dtype,
                 in_shardings=[q_spec, kv_spec, mask_spec],
                 out_shardings=(None, (q_spec, kv_spec))
            )