Unverified Commit ffa24475 authored by Keshav Balasubramanian, committed by GitHub

LN force no weight sharding (#715)



* disallow sharding of layernorm learnable parameters; force duplication
Signed-off-by: Keshav <keshavb@nvidia.com>

* fix tests and support tensors for gamma/beta in layernorms
Signed-off-by: Keshav <keshavb@nvidia.com>

* reverting
Signed-off-by: Keshav <keshavb@nvidia.com>

* added tests for rank-1 gamma/beta sharding
Signed-off-by: Keshav <keshavb@nvidia.com>

* fix lint errors
Signed-off-by: Keshav <keshavb@nvidia.com>

---------
Signed-off-by: Keshav <keshavb@nvidia.com>
parent 2d0ab27f
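The change forces the layernorm/rmsnorm learnable parameters (gamma/beta) to stay fully replicated: each primitive's partition rule now warns when a sharded PartitionSpec reaches it and overrides that spec with `PartitionSpec(None)`. Below is a minimal sketch of that pattern outside TransformerEngine; the helper name `enforce_replicated_weight` and the one-device mesh are assumptions for illustration only, not the library's API.

```python
# Minimal sketch of the replication-enforcement pattern in this commit.
# `enforce_replicated_weight` is a hypothetical helper, not TransformerEngine code.
import warnings

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def enforce_replicated_weight(mesh, spec, name):
    """Warn if `spec` tries to shard the parameter, then force full replication."""
    if any(axis is not None for axis in spec):
        warnings.warn(f"Sharding of parameter {name} is not supported. "
                      "Enforcing no sharding of parameters hidden dim!")
    return NamedSharding(mesh, PartitionSpec(None))


# One-device 'dp' mesh so the sketch runs anywhere; real runs would use >= 2 devices.
devices = np.asarray(jax.devices()[:1]).reshape((1,))
mesh = Mesh(devices, ('dp',))

gamma = jnp.ones((1024,), jnp.float32)
# A request to shard gamma along 'dp' is overridden; gamma ends up replicated.
gamma_ = jax.device_put(gamma, enforce_replicated_weight(mesh, PartitionSpec('dp'), 'gamma'))
```

In the diff itself the same override appears in LayerNormFwdPrimitive, LayerNormBwdPrimitive, RmsNormFwdPrimitive, RmsNormBwdPrimitive, and the FP8 forward primitives: g_sharding/b_sharding (and dgamma_sharding/dbeta_sharding in the backward rules) are pinned to PartitionSpec(None), so XLA inserts whatever resharding is needed to gather a weight that arrives sharded.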
@@ -18,7 +18,7 @@ def generate_configs():
if is_devices_enough(2):
configs.append([2, (2,), ('dp'), MeshResource(dp_resource='dp')])
configs.append([2, (2,), ('tp'), MeshResource(tp_resource='tp')])
if is_devices_enough(4):
TP_size = 2
DP_size = 2
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
import warnings
import pytest
import jax
@@ -20,7 +21,7 @@ DTYPES = [jnp.bfloat16, jnp.float32]
class TestDistributedLayernorm:
def generate_inputs(self, shape, mesh_resource, dtype):
def generate_inputs(self, shape, mesh_resource, dtype, shard_weights):
weight_shape = (shape[-1],)
x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
@@ -34,7 +35,7 @@ class TestDistributedLayernorm:
else:
raise NotImplementedError
g_pspec = b_pspec = PartitionSpec(None)
g_pspec = b_pspec = PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)
@@ -54,8 +55,9 @@ class TestDistributedLayernorm:
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('zero_centered_gamma', [False, True])
@pytest.mark.parametrize('shard_weights', [False, True])
def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype,
zero_centered_gamma):
zero_centered_gamma, shard_weights):
epsilon = 1e-6
ln_type = 'layernorm'
@@ -74,7 +76,7 @@ class TestDistributedLayernorm:
return jnp.mean(output)
(x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \
self.generate_inputs(data_shape, mesh_resource, dtype)
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
data_shape, dtype)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
@@ -84,19 +86,35 @@ class TestDistributedLayernorm:
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
compare_ops(target_func,
ref_func, [x_, gamma_, beta_],
collective_count_ref,
grad_args=(0, 1, 2),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)))
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func,
ref_func, [x_, gamma_, beta_],
collective_count_ref,
grad_args=(0, 1, 2),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)))
except AssertionError as err:
# Layernorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
# when XLA is forced to unshard gamma and/or beta. We can catch
# and ignore that specific error here.
if (g_pspec[-1] is None and b_pspec[-1] is None) or "Expected collective count" not in str(err):
raise err
finally:
for w in warns:
assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
"Layernorm primitive did not raise the correct warning for "
"unsupported sharding of gamma and/or beta"
)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES)
def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype):
@pytest.mark.parametrize('shard_weights', [False, True])
def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights):
epsilon = 1e-6
ln_type = 'rmsnorm'
@@ -111,7 +129,7 @@ class TestDistributedLayernorm:
return jnp.mean(output)
(x, gamma, _), (x_pspec, g_pspec, _) = \
self.generate_inputs(data_shape, mesh_resource, dtype)
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
data_shape, dtype)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
@@ -120,11 +138,26 @@ class TestDistributedLayernorm:
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
compare_ops(target_func,
ref_func, [x_, gamma_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec)))
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func,
ref_func, [x_, gamma_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec)))
except AssertionError as err:
# RmsNorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
# when XLA is forced to unshard gamma. We can catch
# and ignore that specific error here.
if g_pspec[-1] is None or "Expected collective count" not in str(err):
raise err
finally:
for w in warns:
assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
"RmsNorm primitive did not raise the correct warning for "
"unsupported sharding of gamma and/or beta"
)
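For reference, the test pattern added above (wrapping compare_ops in `warnings.catch_warnings(record=True)` and asserting on the recorded messages) reduces to the following self-contained sketch; `noisy_op` is a made-up stand-in, not a TransformerEngine function.

```python
# Standalone sketch of the warning-capture pattern used by the new tests.
import warnings


def noisy_op():
    # Stand-in for a call that emits the forced-replication warning.
    warnings.warn("Enforcing no sharding of parameters hidden dim!")
    return 42


with warnings.catch_warnings(record=True) as warns:
    warnings.simplefilter("always")  # make sure every warning is recorded
    result = noisy_op()

assert result == 42
assert any("Enforcing no sharding" in str(w.message) for w in warns), \
    "expected the forced-replication warning to be raised"
```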
@@ -453,9 +453,21 @@ class LayerNormFwdPrimitive(BasePrimitive):
f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
f"and hurt performance."
)
if g_spec[-1] is not None:
warnings.warn(
f"{LayerNormFwdPrimitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
if b_spec[-1] is not None:
warnings.warn(
f"{LayerNormFwdPrimitive.name} does not support sharding of parameter beta " \
f"Enforcing no sharding of parameters hidden dim! " \
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
b_sharding = NamedSharding(mesh, PartitionSpec(*b_spec))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
b_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
@@ -628,8 +640,15 @@ class LayerNormBwdPrimitive(BasePrimitive):
f"and hurt performance."
)
g_b_spec = get_padded_spec(arg_infos[4])
if g_b_spec[-1] is not None:
warnings.warn(
f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " \
f"of gamma and beta of Layernorm " \
f"Enforcing no sharding of parameters hidden dim! " \
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec))
dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None))
return dx_sharding, dgamma_sharding, dbeta_sharding
@staticmethod
@@ -643,12 +662,19 @@ class LayerNormBwdPrimitive(BasePrimitive):
f"and hurt performance."
)
g_b_spec = get_padded_spec(arg_infos[4])
if g_b_spec[-1] is not None:
warnings.warn(
f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " \
f"of gamma and beta of Layernorm " \
f"Enforcing no sharding of parameters hidden dim! " \
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec))
dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None))
out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding
x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding.
mu_shardings = (NamedSharding(mesh, PartitionSpec(*x_spec[:-1])),) * 2
arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(*g_b_spec)))
arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(None)))
def sharded_impl(dz, x, mu, rsigma, gamma):
local_dx, local_dgamma, local_dbeta = \
@@ -828,8 +854,14 @@ class RmsNormFwdPrimitive(BasePrimitive):
f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
f"and hurt performance."
)
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormFwdPrimitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
arg_shardings = (x_sharding, g_sharding)
@@ -982,8 +1014,13 @@ class RmsNormBwdPrimitive(BasePrimitive):
f"and hurt performance."
)
g_spec = get_padded_spec(arg_infos[3])
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
dgamma_sharding = NamedSharding(mesh, PartitionSpec(None))
return dx_sharding, dgamma_sharding
@staticmethod
@@ -997,12 +1034,17 @@ class RmsNormBwdPrimitive(BasePrimitive):
f"and hurt performance."
)
g_spec = get_padded_spec(arg_infos[3])
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
dgamma_sharding = NamedSharding(mesh, PartitionSpec(None))
out_shardings = dx_sharding, dgamma_sharding
x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding.
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(*g_spec)))
arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(None)))
def sharded_impl(dz, x, rsigma, gamma):
local_dx, local_dgamma = \
@@ -4336,15 +4378,27 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
def partition(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
g_spec = get_padded_spec(arg_infos[1])
b_spec = get_padded_spec(arg_infos[2])
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \
f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! " \
f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
f"and hurt performance."
)
if g_spec[-1] is not None:
warnings.warn(
f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
if b_spec[-1] is not None:
warnings.warn(
f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter beta " \
f"Enforcing no sharding of parameters hidden dim! " \
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
b_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
b_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
mu_sharding = rsigma_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
@@ -4568,14 +4622,20 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
def partition(out_dtype, epsilon, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
g_spec = get_padded_spec(arg_infos[1])
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \
f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
f"and hurt performance."
)
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormFwdFp8Primitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
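Why the warnings mention "extra collective ops": once gamma/beta are forced to be replicated while the data stays sharded along the batch axis, each device can only compute a partial weight gradient over its local batch slice, so the full, replicated dgamma requires an all-reduce across the data-parallel axis. The sketch below illustrates that reduction with `shard_map`/`psum` on a one-device mesh so it runs anywhere; the mesh, shapes, and function names are assumptions, not the code path used by the primitives above.

```python
# Illustration only: why a replicated dgamma implies a cross-device reduction.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = np.asarray(jax.devices()[:1]).reshape((1,))
mesh = Mesh(devices, ('dp',))


def local_dgamma(dz, x_hat):
    # Each shard reduces over its local batch rows only ...
    partial = jnp.sum(dz * x_hat, axis=0)
    # ... so the replicated, full dgamma needs an all-reduce over 'dp'.
    return jax.lax.psum(partial, axis_name='dp')


dz = jnp.ones((8, 16))
x_hat = jnp.ones((8, 16))
dgamma = shard_map(local_dgamma, mesh,
                   in_specs=(P('dp', None), P('dp', None)),
                   out_specs=P(None))(dz, x_hat)
print(dgamma.shape)  # (16,) -- replicated on every device in the mesh
```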