Unverified commit ffa24475, authored by Keshav Balasubramanian, committed by GitHub
Browse files

Ln force no weight sharding (#715)



* disallow sharding of layernorm learnable parameters; force duplication
Signed-off-by: Keshav <keshavb@nvidia.com>

* fix tests and support tensors for gamma/beta in layernorms
Signed-off-by: Keshav <keshavb@nvidia.com>

* reverting
Signed-off-by: Keshav <keshavb@nvidia.com>

* added tests for rank-1 gamma/beta sharding
Signed-off-by: Keshav <keshavb@nvidia.com>

* fix lint errors
Signed-off-by: Keshav <keshavb@nvidia.com>

---------
Signed-off-by: Keshav <keshavb@nvidia.com>
parent 2d0ab27f
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
import warnings
import pytest import pytest
import jax import jax
...@@ -20,7 +21,7 @@ DTYPES = [jnp.bfloat16, jnp.float32] ...@@ -20,7 +21,7 @@ DTYPES = [jnp.bfloat16, jnp.float32]
class TestDistributedLayernorm: class TestDistributedLayernorm:
def generate_inputs(self, shape, mesh_resource, dtype): def generate_inputs(self, shape, mesh_resource, dtype, shard_weights):
weight_shape = (shape[-1],) weight_shape = (shape[-1],)
x = random.normal(random.PRNGKey(1124), shape, dtype=dtype) x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
...@@ -34,7 +35,7 @@ class TestDistributedLayernorm: ...@@ -34,7 +35,7 @@ class TestDistributedLayernorm:
else: else:
raise NotImplementedError raise NotImplementedError
g_pspec = b_pspec = PartitionSpec(None) g_pspec = b_pspec = PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
return (x, gamma, beta), (x_pspec, g_pspec, b_pspec) return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)
...@@ -54,8 +55,9 @@ class TestDistributedLayernorm: ...@@ -54,8 +55,9 @@ class TestDistributedLayernorm:
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]]) @pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('zero_centered_gamma', [False, True]) @pytest.mark.parametrize('zero_centered_gamma', [False, True])
@pytest.mark.parametrize('shard_weights', [False, True])
def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype,
zero_centered_gamma): zero_centered_gamma, shard_weights):
epsilon = 1e-6 epsilon = 1e-6
ln_type = 'layernorm' ln_type = 'layernorm'
...@@ -74,7 +76,7 @@ class TestDistributedLayernorm: ...@@ -74,7 +76,7 @@ class TestDistributedLayernorm:
return jnp.mean(output) return jnp.mean(output)
(x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \ (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \
self.generate_inputs(data_shape, mesh_resource, dtype) self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type, collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
data_shape, dtype) data_shape, dtype)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
...@@ -84,6 +86,8 @@ class TestDistributedLayernorm: ...@@ -84,6 +86,8 @@ class TestDistributedLayernorm:
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec)) beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func, compare_ops(target_func,
ref_func, [x_, gamma_, beta_], ref_func, [x_, gamma_, beta_],
collective_count_ref, collective_count_ref,
...@@ -92,11 +96,25 @@ class TestDistributedLayernorm: ...@@ -92,11 +96,25 @@ class TestDistributedLayernorm:
metric_bwd_dtype=dtype, metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec, b_pspec), in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec))) out_shardings=(None, (x_pspec, g_pspec, b_pspec)))
except AssertionError as err:
# Layernorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
# when XLA is forced to unshard gamma and/or beta. We can catch
# and ignore that specific error here.
if (g_pspec[-1] is None and b_pspec[-1] is None) or "Expected collective count" not in str(err):
raise err
finally:
for w in warns:
assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
"Layernorm primitive did not raise the correct warning for "
"unsupported sharding of gamma and/or beta"
)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]]) @pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize('dtype', DTYPES)
def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype): @pytest.mark.parametrize('shard_weights', [False, True])
def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights):
epsilon = 1e-6 epsilon = 1e-6
ln_type = 'rmsnorm' ln_type = 'rmsnorm'
...@@ -111,7 +129,7 @@ class TestDistributedLayernorm: ...@@ -111,7 +129,7 @@ class TestDistributedLayernorm:
return jnp.mean(output) return jnp.mean(output)
(x, gamma, _), (x_pspec, g_pspec, _) = \ (x, gamma, _), (x_pspec, g_pspec, _) = \
self.generate_inputs(data_shape, mesh_resource, dtype) self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type, collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
data_shape, dtype) data_shape, dtype)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
...@@ -120,6 +138,8 @@ class TestDistributedLayernorm: ...@@ -120,6 +138,8 @@ class TestDistributedLayernorm:
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func, compare_ops(target_func,
ref_func, [x_, gamma_], ref_func, [x_, gamma_],
collective_count_ref, collective_count_ref,
...@@ -128,3 +148,16 @@ class TestDistributedLayernorm: ...@@ -128,3 +148,16 @@ class TestDistributedLayernorm:
metric_bwd_dtype=dtype, metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec), in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec))) out_shardings=(None, (x_pspec, g_pspec)))
except AssertionError as err:
# RmsNorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
# when XLA is forced to unshard gamma. We can catch
# and ignore that specific error here.
if g_pspec[-1] is None or "Expected collective count" not in str(err):
raise err
finally:
for w in warns:
assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
"RmsNorm primitive did not raise the correct warning for "
"unsupported sharding of gamma and/or beta"
)
...@@ -453,9 +453,21 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -453,9 +453,21 @@ class LayerNormFwdPrimitive(BasePrimitive):
f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
f"and hurt performance." f"and hurt performance."
) )
if g_spec[-1] is not None:
warnings.warn(
f"{LayerNormFwdPrimitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
if b_spec[-1] is not None:
warnings.warn(
f"{LayerNormFwdPrimitive.name} does not support sharding of parameter beta " \
f"Enforcing no sharding of parameters hidden dim! " \
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec)) g_sharding = NamedSharding(mesh, PartitionSpec(None))
b_sharding = NamedSharding(mesh, PartitionSpec(*b_spec)) b_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding out_sharding = x_sharding
mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
...@@ -628,8 +640,15 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -628,8 +640,15 @@ class LayerNormBwdPrimitive(BasePrimitive):
f"and hurt performance." f"and hurt performance."
) )
g_b_spec = get_padded_spec(arg_infos[4]) g_b_spec = get_padded_spec(arg_infos[4])
if g_b_spec[-1] is not None:
warnings.warn(
f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " \
f"of gamma and beta of Layernorm " \
f"Enforcing no sharding of parameters hidden dim! " \
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec)) dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None))
return dx_sharding, dgamma_sharding, dbeta_sharding return dx_sharding, dgamma_sharding, dbeta_sharding
@staticmethod @staticmethod
...@@ -643,12 +662,19 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -643,12 +662,19 @@ class LayerNormBwdPrimitive(BasePrimitive):
f"and hurt performance." f"and hurt performance."
) )
g_b_spec = get_padded_spec(arg_infos[4]) g_b_spec = get_padded_spec(arg_infos[4])
if g_b_spec[-1] is not None:
warnings.warn(
f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " \
f"of gamma and beta of Layernorm " \
f"Enforcing no sharding of parameters hidden dim! " \
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec)) dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None))
out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding
x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding. x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding.
mu_shardings = (NamedSharding(mesh, PartitionSpec(*x_spec[:-1])),) * 2 mu_shardings = (NamedSharding(mesh, PartitionSpec(*x_spec[:-1])),) * 2
arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(*g_b_spec))) arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(None)))
def sharded_impl(dz, x, mu, rsigma, gamma): def sharded_impl(dz, x, mu, rsigma, gamma):
local_dx, local_dgamma, local_dbeta = \ local_dx, local_dgamma, local_dbeta = \
...@@ -828,8 +854,14 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -828,8 +854,14 @@ class RmsNormFwdPrimitive(BasePrimitive):
f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
f"and hurt performance." f"and hurt performance."
) )
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormFwdPrimitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec)) g_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding out_sharding = x_sharding
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
arg_shardings = (x_sharding, g_sharding) arg_shardings = (x_sharding, g_sharding)
...@@ -982,8 +1014,13 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -982,8 +1014,13 @@ class RmsNormBwdPrimitive(BasePrimitive):
f"and hurt performance." f"and hurt performance."
) )
g_spec = get_padded_spec(arg_infos[3]) g_spec = get_padded_spec(arg_infos[3])
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec)) dgamma_sharding = NamedSharding(mesh, PartitionSpec(None))
return dx_sharding, dgamma_sharding return dx_sharding, dgamma_sharding
@staticmethod @staticmethod
...@@ -997,12 +1034,17 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -997,12 +1034,17 @@ class RmsNormBwdPrimitive(BasePrimitive):
f"and hurt performance." f"and hurt performance."
) )
g_spec = get_padded_spec(arg_infos[3]) g_spec = get_padded_spec(arg_infos[3])
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec)) dgamma_sharding = NamedSharding(mesh, PartitionSpec(None))
out_shardings = dx_sharding, dgamma_sharding out_shardings = dx_sharding, dgamma_sharding
x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding. x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding.
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(*g_spec))) arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(None)))
def sharded_impl(dz, x, rsigma, gamma): def sharded_impl(dz, x, rsigma, gamma):
local_dx, local_dgamma = \ local_dx, local_dgamma = \
...@@ -4336,15 +4378,27 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -4336,15 +4378,27 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
def partition(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): def partition(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
del result_infos del result_infos
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
g_spec = get_padded_spec(arg_infos[1])
b_spec = get_padded_spec(arg_infos[2])
if x_spec[-1] is not None: if x_spec[-1] is not None:
warnings.warn( warnings.warn(
f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \ f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! " \
f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
f"and hurt performance." f"and hurt performance."
) )
if g_spec[-1] is not None:
warnings.warn(
f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
if b_spec[-1] is not None:
warnings.warn(
f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter beta " \
f"Enforcing no sharding of parameters hidden dim! " \
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) g_sharding = NamedSharding(mesh, PartitionSpec(None))
b_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) b_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding out_sharding = x_sharding
mu_sharding = rsigma_sharding = NamedSharding( mu_sharding = rsigma_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1])) mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
...@@ -4568,14 +4622,20 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -4568,14 +4622,20 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
def partition(out_dtype, epsilon, mesh, arg_infos, result_infos): def partition(out_dtype, epsilon, mesh, arg_infos, result_infos):
del result_infos del result_infos
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
g_spec = get_padded_spec(arg_infos[1])
if x_spec[-1] is not None: if x_spec[-1] is not None:
warnings.warn( warnings.warn(
f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \ f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \
f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
f"and hurt performance." f"and hurt performance."
) )
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormFwdFp8Primitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) g_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding out_sharding = x_sharding
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1])) rsigma_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment