Unverified commit 30cad990, authored by Alp Dener, committed by GitHub.

[JAX] Bugfix for insufficient GPUs crash in distributed ops (#505)



* Fixed minor bug with DistributedConfigsHelper prematurely crashing the test for insufficient GPUs before @pytest.skip condition.
Signed-off-by: Alp Dener <adener@nvidia.com>

* Update tests/jax/distributed_configs_helper.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Debug PyTest errors when running on single-GPU system
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 74eb7c33
...@@ -8,14 +8,22 @@ from transformer_engine.jax.softmax import SoftmaxType ...@@ -8,14 +8,22 @@ from transformer_engine.jax.softmax import SoftmaxType
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType\ from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType\
class ShardingConfigs(object): class DistributedConfigsHelper(object):
def __init__(self, num_gpus=len(jax.devices())): def __init__(self, num_gpus=len(jax.devices())):
super().__init__() super().__init__()
if num_gpus < 2:
raise ValueError(f"ShardingConfig: Need at least 2 GPUs.")
self.device_count = min(num_gpus, 8) self.device_count = min(num_gpus, 8)
if self.device_count < 2:
self.layernorm_refs = []
self.softmax_types = []
self.softmax_refs = []
self.self_attn_bias_types = []
self.self_attn_mask_types = []
self.self_attn_refs = []
self.cross_attn_mask_types = []
self.cross_attn_refs = []
return
mesh_configs = [ mesh_configs = [
((self.device_count, 1), ("dp", None), ShardingType.DP), ((self.device_count, 1), ("dp", None), ShardingType.DP),
((self.device_count, 1), ("tp", None), ShardingType.TP_COL), ((self.device_count, 1), ("tp", None), ShardingType.TP_COL),
......
...@@ -44,7 +44,7 @@ def fixture_backend(request): ...@@ -44,7 +44,7 @@ def fixture_backend(request):
@dataclass @dataclass
class CustomOpsTestHelper: class DistributedOpsHelper:
qkv_shape: Tuple[int,int,int,int] = (32, 128, 16, 64) qkv_shape: Tuple[int,int,int,int] = (32, 128, 16, 64)
pad_ratio: float = 0.3 pad_ratio: float = 0.3
dropout_prob: float = 0.1 dropout_prob: float = 0.1
......
...@@ -11,18 +11,19 @@ from jax import random ...@@ -11,18 +11,19 @@ from jax import random
from jax.sharding import NamedSharding from jax.sharding import NamedSharding
from utils import is_devices_enough from utils import is_devices_enough
from sharding_configs import * from distributed_configs_helper import *
from custom_ops_helper import * from distributed_ops_helper import *
from transformer_engine.jax.sharding import global_shard_guard from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType
configs = ShardingConfigs() # default device count is len(jax.devices()) configs = DistributedConfigsHelper() # default device count is len(jax.devices())
helper = CustomOpsTestHelper() ops = DistributedOpsHelper() # default data type is jnp.float16
@pytest.mark.skipif(not helper.use_custom_partitioning(), @pytest.mark.skipif(not is_devices_enough(configs.device_count),
reason='Insufficient number of GPUs, need at least 2.')
@pytest.mark.skipif(not ops.use_custom_partitioning(),
reason='TE/JAX version does not support sharding with ' + \ reason='TE/JAX version does not support sharding with ' + \
'jax.experimental.custom_partitioning.') 'jax.experimental.custom_partitioning.')
@pytest.mark.skipif(not is_devices_enough(configs.device_count), reason='Num of GPU is not enough')
class TestCustomPartitioningOpsGenerator: class TestCustomPartitioningOpsGenerator:
@pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref', @pytest.mark.parametrize('mesh_shape, mesh_names, sharding_type, collective_ref',
...@@ -32,33 +33,33 @@ class TestCustomPartitioningOpsGenerator: ...@@ -32,33 +33,33 @@ class TestCustomPartitioningOpsGenerator:
zero_centered_gamma): zero_centered_gamma):
epsilon = 1e-6 epsilon = 1e-6
custom_func = partial(helper.custom_layernorm, custom_func = partial(ops.custom_layernorm,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
sharding_type=sharding_type) sharding_type=sharding_type)
reference_func = partial(helper.reference_layernorm, reference_func = partial(ops.reference_layernorm,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon) epsilon=epsilon)
batch_size, _, num_heads, head_dim = helper.qkv_shape batch_size, _, num_heads, head_dim = ops.qkv_shape
hidden_size = num_heads*head_dim hidden_size = num_heads*head_dim
input_shape = (batch_size, hidden_size) input_shape = (batch_size, hidden_size)
other_shape = (hidden_size, ) other_shape = (hidden_size, )
x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=helper.dtype) x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=ops.dtype)
gamma_ = jnp.ones(other_shape, dtype=helper.dtype) gamma_ = jnp.ones(other_shape, dtype=ops.dtype)
beta_ = jnp.ones(other_shape, dtype=helper.dtype) beta_ = jnp.ones(other_shape, dtype=ops.dtype)
x_spec, gamma_spec, beta_spec = helper.get_sharding_spec(mesh_names, sharding_type) x_spec, gamma_spec, beta_spec = ops.get_sharding_spec(mesh_names, sharding_type)
devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, mesh_names) mesh = jax.sharding.Mesh(devices, mesh_names)
with mesh, global_shard_guard(helper.get_sharding_resource(mesh_names, sharding_type)): with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
x_ = jax.device_put(x_, NamedSharding(mesh, x_spec)) x_ = jax.device_put(x_, NamedSharding(mesh, x_spec))
gamma_ = jax.device_put(gamma_, NamedSharding(mesh, gamma_spec)) gamma_ = jax.device_put(gamma_, NamedSharding(mesh, gamma_spec))
beta_ = jax.device_put(beta_, NamedSharding(mesh, beta_spec)) beta_ = jax.device_put(beta_, NamedSharding(mesh, beta_spec))
helper.compare_ops( ops.compare_ops(
custom_func, reference_func, collective_ref, custom_func, reference_func, collective_ref,
x_, gamma_, beta_, grad_args=(0, 1, 2), dtype=helper.dtype, x_, gamma_, beta_, grad_args=(0, 1, 2), dtype=ops.dtype,
in_shardings=[x_spec, gamma_spec, beta_spec], in_shardings=[x_spec, gamma_spec, beta_spec],
out_shardings=(None, (x_spec, gamma_spec, beta_spec)) out_shardings=(None, (x_spec, gamma_spec, beta_spec))
) )
...@@ -67,25 +68,25 @@ class TestCustomPartitioningOpsGenerator: ...@@ -67,25 +68,25 @@ class TestCustomPartitioningOpsGenerator:
configs.layernorm_refs) configs.layernorm_refs)
def test_rmsnorm(self, mesh_shape, mesh_names, sharding_type, collective_ref): def test_rmsnorm(self, mesh_shape, mesh_names, sharding_type, collective_ref):
epsilon = 1e-6 epsilon = 1e-6
custom_func = partial(helper.custom_rmsnorm, epsilon=epsilon,sharding_type=sharding_type) custom_func = partial(ops.custom_rmsnorm, epsilon=epsilon,sharding_type=sharding_type)
reference_func = partial(helper.reference_rmsnorm, epsilon=epsilon) reference_func = partial(ops.reference_rmsnorm, epsilon=epsilon)
batch_size, _, num_heads, head_dim = helper.qkv_shape batch_size, _, num_heads, head_dim = ops.qkv_shape
hidden_size = num_heads*head_dim hidden_size = num_heads*head_dim
input_shape = (batch_size, hidden_size) input_shape = (batch_size, hidden_size)
other_shape = (hidden_size, ) other_shape = (hidden_size, )
x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=helper.dtype) x_ = random.normal(random.PRNGKey(1124), input_shape, dtype=ops.dtype)
gamma_ = jnp.ones(other_shape, dtype=helper.dtype) gamma_ = jnp.ones(other_shape, dtype=ops.dtype)
x_spec, gamma_spec = helper.get_sharding_spec(mesh_names, sharding_type) x_spec, gamma_spec = ops.get_sharding_spec(mesh_names, sharding_type)
devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, mesh_names) mesh = jax.sharding.Mesh(devices, mesh_names)
with mesh, global_shard_guard(helper.get_sharding_resource(mesh_names, sharding_type)): with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
x_ = jax.device_put(x_, NamedSharding(mesh, x_spec)) x_ = jax.device_put(x_, NamedSharding(mesh, x_spec))
gamma_ = jax.device_put(gamma_, NamedSharding(mesh, gamma_spec)) gamma_ = jax.device_put(gamma_, NamedSharding(mesh, gamma_spec))
helper.compare_ops( ops.compare_ops(
custom_func, reference_func, collective_ref, custom_func, reference_func, collective_ref,
x_, gamma_, grad_args=(0, 1), dtype=helper.dtype, x_, gamma_, grad_args=(0, 1), dtype=ops.dtype,
in_shardings=[x_spec, gamma_spec], in_shardings=[x_spec, gamma_spec],
out_shardings=(None, (x_spec, gamma_spec)) out_shardings=(None, (x_spec, gamma_spec))
) )
...@@ -95,36 +96,36 @@ class TestCustomPartitioningOpsGenerator: ...@@ -95,36 +96,36 @@ class TestCustomPartitioningOpsGenerator:
@pytest.mark.parametrize('softmax_type', configs.softmax_types) @pytest.mark.parametrize('softmax_type', configs.softmax_types)
def test_softmax(self, mesh_shape, mesh_names, sharding_type, collective_ref, def test_softmax(self, mesh_shape, mesh_names, sharding_type, collective_ref,
softmax_type): softmax_type):
batch_size, seq_len, num_heads, head_dim = helper.qkv_shape batch_size, seq_len, num_heads, head_dim = ops.qkv_shape
scale_factor = 1./jnp.sqrt(head_dim) scale_factor = 1./jnp.sqrt(head_dim)
custom_func = partial(helper.custom_softmax, custom_func = partial(ops.custom_softmax,
scale_factor=scale_factor, scale_factor=scale_factor,
softmax_type=softmax_type, softmax_type=softmax_type,
sharding_type=sharding_type) sharding_type=sharding_type)
reference_func = partial(helper.reference_softmax, reference_func = partial(ops.reference_softmax,
scale_factor=scale_factor, scale_factor=scale_factor,
softmax_type=softmax_type) softmax_type=softmax_type)
input_size = (batch_size, num_heads, seq_len, seq_len) input_size = (batch_size, num_heads, seq_len, seq_len)
x_ = random.normal(random.PRNGKey(1124), input_size, dtype=helper.dtype) x_ = random.normal(random.PRNGKey(1124), input_size, dtype=ops.dtype)
pad_len = int(seq_len * helper.pad_ratio) pad_len = int(seq_len * ops.pad_ratio)
valid_len = seq_len - pad_len valid_len = seq_len - pad_len
tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=helper.dtype), tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=ops.dtype),
jnp.zeros((batch_size, pad_len), dtype=helper.dtype)), jnp.zeros((batch_size, pad_len), dtype=ops.dtype)),
axis=-1) axis=-1)
mask_ = helper.make_mask(tokens, tokens, AttnMaskType.PADDING_MASK) mask_ = ops.make_mask(tokens, tokens, AttnMaskType.PADDING_MASK)
x_spec, mask_spec = helper.get_sharding_spec(mesh_names, sharding_type) x_spec, mask_spec = ops.get_sharding_spec(mesh_names, sharding_type)
devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, mesh_names) mesh = jax.sharding.Mesh(devices, mesh_names)
with mesh, global_shard_guard(helper.get_sharding_resource(mesh_names, sharding_type)): with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
x_ = jax.device_put(x_, NamedSharding(mesh, x_spec)) x_ = jax.device_put(x_, NamedSharding(mesh, x_spec))
mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec)) mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec))
helper.compare_ops( ops.compare_ops(
custom_func, reference_func, collective_ref, custom_func, reference_func, collective_ref,
(0), x_, mask_, grad_args=(0), dtype=helper.dtype, (0), x_, mask_, grad_args=(0), dtype=ops.dtype,
in_shardings=[x_spec, mask_spec], in_shardings=[x_spec, mask_spec],
out_shardings=(None, (x_spec)) out_shardings=(None, (x_spec))
) )
...@@ -135,25 +136,25 @@ class TestCustomPartitioningOpsGenerator: ...@@ -135,25 +136,25 @@ class TestCustomPartitioningOpsGenerator:
@pytest.mark.parametrize('attn_mask_type', configs.self_attn_mask_types) @pytest.mark.parametrize('attn_mask_type', configs.self_attn_mask_types)
def test_self_fused_attn(self, mesh_shape, mesh_names, sharding_type, collective_ref, def test_self_fused_attn(self, mesh_shape, mesh_names, sharding_type, collective_ref,
attn_bias_type, attn_mask_type, backend): attn_bias_type, attn_mask_type, backend):
batch_size, seq_len, num_heads, head_dim = helper.qkv_shape batch_size, seq_len, num_heads, head_dim = ops.qkv_shape
helper.check_fused_attn_inputs(seq_len, seq_len, head_dim, ops.check_fused_attn_inputs(seq_len, seq_len, head_dim,
helper.pad_ratio, helper.dropout_prob, ops.pad_ratio, ops.dropout_prob,
attn_bias_type, attn_mask_type, backend) attn_bias_type, attn_mask_type, backend)
dropout_rng = random.PRNGKey(91023051) dropout_rng = random.PRNGKey(91023051)
split_rng = random.split(dropout_rng, configs.device_count) split_rng = random.split(dropout_rng, configs.device_count)
scale_factor = 1./jnp.sqrt(head_dim) scale_factor = 1./jnp.sqrt(head_dim)
custom_func = partial(helper.custom_self_fused_attn, custom_func = partial(ops.custom_self_fused_attn,
rng_key=split_rng, rng_key=split_rng,
dropout_prob=helper.dropout_prob, dropout_prob=ops.dropout_prob,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scale_factor, scaling_factor=scale_factor,
sharding_type=sharding_type) sharding_type=sharding_type)
reference_func = partial(helper.reference_self_fused_attn, reference_func = partial(ops.reference_self_fused_attn,
rng_key=dropout_rng, rng_key=dropout_rng,
dropout_prob=helper.dropout_prob, dropout_prob=ops.dropout_prob,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scale_factor) scaling_factor=scale_factor)
...@@ -162,27 +163,27 @@ class TestCustomPartitioningOpsGenerator: ...@@ -162,27 +163,27 @@ class TestCustomPartitioningOpsGenerator:
subkeys = random.split(key, 2) subkeys = random.split(key, 2)
qkv_shape = (batch_size, seq_len, 3, num_heads, head_dim) qkv_shape = (batch_size, seq_len, 3, num_heads, head_dim)
qkv_ = random.normal(subkeys[0], qkv_shape, dtype=helper.dtype) qkv_ = random.normal(subkeys[0], qkv_shape, dtype=ops.dtype)
bias_shape = (1, num_heads, seq_len, seq_len) bias_shape = (1, num_heads, seq_len, seq_len)
bias_ = random.normal(subkeys[1], bias_shape, dtype=helper.dtype) bias_ = random.normal(subkeys[1], bias_shape, dtype=ops.dtype)
pad_len = int(seq_len * helper.pad_ratio) pad_len = int(seq_len * ops.pad_ratio)
valid_len = seq_len - pad_len valid_len = seq_len - pad_len
tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=helper.dtype), tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=ops.dtype),
jnp.zeros((batch_size, pad_len), dtype=helper.dtype)), jnp.zeros((batch_size, pad_len), dtype=ops.dtype)),
axis=-1) axis=-1)
mask_ = helper.make_mask(tokens, tokens, attn_mask_type) mask_ = ops.make_mask(tokens, tokens, attn_mask_type)
qkv_spec, bias_spec, mask_spec = helper.get_sharding_spec(mesh_names, sharding_type) qkv_spec, bias_spec, mask_spec = ops.get_sharding_spec(mesh_names, sharding_type)
devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, mesh_names) mesh = jax.sharding.Mesh(devices, mesh_names)
with mesh, global_shard_guard(helper.get_sharding_resource(mesh_names, sharding_type)): with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
qkv_ = jax.device_put(qkv_, NamedSharding(mesh, qkv_spec)) qkv_ = jax.device_put(qkv_, NamedSharding(mesh, qkv_spec))
bias_ = jax.device_put(bias_, NamedSharding(mesh, bias_spec)) bias_ = jax.device_put(bias_, NamedSharding(mesh, bias_spec))
mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec)) mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec))
helper.compare_ops( ops.compare_ops(
custom_func, reference_func, collective_ref, custom_func, reference_func, collective_ref,
qkv_, bias_, mask_, grad_args=(0, 1), dtype=helper.dtype, qkv_, bias_, mask_, grad_args=(0, 1), dtype=ops.dtype,
in_shardings=[qkv_spec, bias_spec, mask_spec], in_shardings=[qkv_spec, bias_spec, mask_spec],
out_shardings=(None, (qkv_spec, bias_spec)) out_shardings=(None, (qkv_spec, bias_spec))
) )
...@@ -192,24 +193,24 @@ class TestCustomPartitioningOpsGenerator: ...@@ -192,24 +193,24 @@ class TestCustomPartitioningOpsGenerator:
@pytest.mark.parametrize('attn_mask_type', configs.cross_attn_mask_types) @pytest.mark.parametrize('attn_mask_type', configs.cross_attn_mask_types)
def test_cross_fused_attn(self, mesh_shape, mesh_names, sharding_type, collective_ref, def test_cross_fused_attn(self, mesh_shape, mesh_names, sharding_type, collective_ref,
attn_mask_type, backend): attn_mask_type, backend):
batch_size, seq_len, num_heads, head_dim = helper.qkv_shape batch_size, seq_len, num_heads, head_dim = ops.qkv_shape
helper.check_fused_attn_inputs(seq_len, seq_len, head_dim, ops.check_fused_attn_inputs(seq_len, seq_len, head_dim,
helper.pad_ratio, helper.dropout_prob, ops.pad_ratio, ops.dropout_prob,
AttnBiasType.NO_BIAS, attn_mask_type, backend) AttnBiasType.NO_BIAS, attn_mask_type, backend)
dropout_rng = random.PRNGKey(91023051) dropout_rng = random.PRNGKey(91023051)
split_rng = random.split(dropout_rng, configs.device_count) split_rng = random.split(dropout_rng, configs.device_count)
scale_factor = 1./jnp.sqrt(head_dim) scale_factor = 1./jnp.sqrt(head_dim)
custom_func = partial(helper.custom_cross_fused_attn, custom_func = partial(ops.custom_cross_fused_attn,
rng_key=split_rng, rng_key=split_rng,
dropout_prob=helper.dropout_prob, dropout_prob=ops.dropout_prob,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scale_factor, scaling_factor=scale_factor,
sharding_type=sharding_type) sharding_type=sharding_type)
reference_func = partial(helper.reference_cross_fused_attn, reference_func = partial(ops.reference_cross_fused_attn,
rng_key=split_rng, rng_key=split_rng,
dropout_prob=helper.dropout_prob, dropout_prob=ops.dropout_prob,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scale_factor) scaling_factor=scale_factor)
...@@ -217,27 +218,27 @@ class TestCustomPartitioningOpsGenerator: ...@@ -217,27 +218,27 @@ class TestCustomPartitioningOpsGenerator:
subkeys = random.split(key, 2) subkeys = random.split(key, 2)
q_shape = (batch_size, seq_len, num_heads, head_dim) q_shape = (batch_size, seq_len, num_heads, head_dim)
q_ = random.normal(subkeys[0], q_shape, dtype=helper.dtype) q_ = random.normal(subkeys[0], q_shape, dtype=ops.dtype)
kv_shape = (batch_size, seq_len, 2, num_heads, head_dim) kv_shape = (batch_size, seq_len, 2, num_heads, head_dim)
kv_ = random.normal(subkeys[1], kv_shape, dtype=helper.dtype) kv_ = random.normal(subkeys[1], kv_shape, dtype=ops.dtype)
pad_len = int(seq_len * helper.pad_ratio) pad_len = int(seq_len * ops.pad_ratio)
valid_len = seq_len - pad_len valid_len = seq_len - pad_len
tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=helper.dtype), tokens = jnp.concatenate((jnp.ones((batch_size, valid_len), dtype=ops.dtype),
jnp.zeros((batch_size, pad_len), dtype=helper.dtype)), jnp.zeros((batch_size, pad_len), dtype=ops.dtype)),
axis=-1) axis=-1)
mask_ = helper.make_mask(tokens, tokens, attn_mask_type) mask_ = ops.make_mask(tokens, tokens, attn_mask_type)
q_spec, kv_spec, mask_spec = helper.get_sharding_spec(mesh_names, sharding_type) q_spec, kv_spec, mask_spec = ops.get_sharding_spec(mesh_names, sharding_type)
devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:configs.device_count]).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, mesh_names) mesh = jax.sharding.Mesh(devices, mesh_names)
with mesh, global_shard_guard(helper.get_sharding_resource(mesh_names, sharding_type)): with mesh, global_shard_guard(ops.get_sharding_resource(mesh_names, sharding_type)):
q_ = jax.device_put(q_, NamedSharding(mesh, q_spec)) q_ = jax.device_put(q_, NamedSharding(mesh, q_spec))
kv_= jax.device_put(kv_, NamedSharding(mesh, kv_spec)) kv_= jax.device_put(kv_, NamedSharding(mesh, kv_spec))
mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec)) mask_ = jax.device_put(mask_, NamedSharding(mesh, mask_spec))
helper.compare_ops( ops.compare_ops(
custom_func, reference_func, collective_ref, custom_func, reference_func, collective_ref,
q_, kv_, mask_, grad_args=(0, 1), dtype=helper.dtype, q_, kv_, mask_, grad_args=(0, 1), dtype=ops.dtype,
in_shardings=[q_spec, kv_spec, mask_spec], in_shardings=[q_spec, kv_spec, mask_spec],
out_shardings=(None, (q_spec, kv_spec)) out_shardings=(None, (q_spec, kv_spec))
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment