Unverified Commit 04040957 authored by Alp Dener's avatar Alp Dener Committed by GitHub
Browse files

[JAX] Bugfix for softmax primitives accepting invalid input sharding (#664)



* Softmax now forces XLA to unshard the hidden dimension with a warning. Unittests updated to check for numerics and warning with bad sharding
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* correcting cudnn-frontend version
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* fixed mismatched output sharding
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* combined softmax tests and fixed code style/linting issues
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

---------
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>
parent 8bba5eeb
...@@ -2,7 +2,9 @@ ...@@ -2,7 +2,9 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
import warnings
import pytest import pytest
from functools import partial
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -26,7 +28,7 @@ class TestDistributedSoftmax: ...@@ -26,7 +28,7 @@ class TestDistributedSoftmax:
all_reduce_loss_bytes = 4 # 1 * FP32 all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(self, shape, mesh_resource, softmax_type, dtype): def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding):
batch, _, sqelen, _ = shape batch, _, sqelen, _ = shape
x = random.normal(random.PRNGKey(1124), shape, dtype=dtype) x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
...@@ -35,25 +37,22 @@ class TestDistributedSoftmax: ...@@ -35,25 +37,22 @@ class TestDistributedSoftmax:
else: else:
mask = make_self_mask(batch, sqelen) mask = make_self_mask(batch, sqelen)
x_pspec = PartitionSpec(mesh_resource.dp_resource, mesh_resource.tp_resource, None, None) if not bad_sharding:
x_pspec = PartitionSpec(mesh_resource.dp_resource, mesh_resource.tp_resource,
None, None)
else:
x_pspec = PartitionSpec(mesh_resource.dp_resource, None,
None, mesh_resource.tp_resource)
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)
return (x, mask), (x_pspec, mask_pspec) return (x, mask), (x_pspec, mask_pspec)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) @staticmethod
@pytest.mark.parametrize('data_shape', [[32, 12, 128, 128], [64, 16, 1024, 1024]]) def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED):
@pytest.mark.parametrize(
'softmax_type',
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED])
@pytest.mark.parametrize('scale_factor', [1.0, 3.0])
@pytest.mark.parametrize('dtype', DTYPES)
def test_softmax(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
softmax_type, scale_factor, dtype):
def target_func(x, mask):
return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type)) return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
def ref_func(x, mask): @staticmethod
def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
bias = None bias = None
if mask is not None: if mask is not None:
bias = jax.lax.select(mask > 0, bias = jax.lax.select(mask > 0,
...@@ -64,8 +63,24 @@ class TestDistributedSoftmax: ...@@ -64,8 +63,24 @@ class TestDistributedSoftmax:
output = jax.nn.softmax(x * scale_factor) output = jax.nn.softmax(x * scale_factor)
return jnp.mean(output) return jnp.mean(output)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize(
'softmax_type',
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED])
@pytest.mark.parametrize('scale_factor', [1.0, 3.0])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('bad_sharding', [False, True])
def test_softmax(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
softmax_type, scale_factor, dtype, bad_sharding):
target_func = partial(self.target_func,
scale_factor=scale_factor,
softmax_type=softmax_type)
ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
(x, mask), (x_pspec, mask_pspec) = \ (x, mask), (x_pspec, mask_pspec) = \
self.generate_inputs(data_shape, mesh_resource, softmax_type, dtype) self.generate_inputs(data_shape, mesh_resource, softmax_type, dtype, bad_sharding)
collective_count_ref = self.generate_collectives_count_ref() collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
...@@ -73,6 +88,8 @@ class TestDistributedSoftmax: ...@@ -73,6 +88,8 @@ class TestDistributedSoftmax:
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec))
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func, compare_ops(target_func,
ref_func, [x_, mask_], ref_func, [x_, mask_],
collective_count_ref, collective_count_ref,
...@@ -81,3 +98,16 @@ class TestDistributedSoftmax: ...@@ -81,3 +98,16 @@ class TestDistributedSoftmax:
metric_bwd_dtype=dtype, metric_bwd_dtype=dtype,
in_shardings=(x_pspec, mask_pspec), in_shardings=(x_pspec, mask_pspec),
out_shardings=(None, (x_pspec,))) out_shardings=(None, (x_pspec,)))
except AssertionError as err:
# Softmax should still produce the correct numerical result with
# bad sharding. However, the collective count may not be the same
# when XLA is forced to unshard the hidden dimension. We can catch
# and ignore that specific error here.
if not bad_sharding or "Expected collective count" not in str(err):
raise err
finally:
for w in warns:
assert "Sharding the hidden dimension is not supported" in str(w), (
"Softmax primitive did not raise the correct warning for "
"unsupported sharding in the hidden dimension."
)
...@@ -1029,6 +1029,7 @@ class SoftmaxPrimitive(BasePrimitive): ...@@ -1029,6 +1029,7 @@ class SoftmaxPrimitive(BasePrimitive):
Softmax Primitive Softmax Primitive
""" """
max_k_seqlen_supported = 4096 max_k_seqlen_supported = 4096
name = "te_softmax_internal_placeholder"
@staticmethod @staticmethod
@abstractmethod @abstractmethod
...@@ -1118,26 +1119,37 @@ class SoftmaxPrimitive(BasePrimitive): ...@@ -1118,26 +1119,37 @@ class SoftmaxPrimitive(BasePrimitive):
out_bdims = logits_bdim out_bdims = logits_bdim
return primitive.bind(logits, scale_factor=scale_factor), out_bdims return primitive.bind(logits, scale_factor=scale_factor), out_bdims
@staticmethod @classmethod
def forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): def forward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
""" """
softmax_forward infer_sharding_from_operands softmax_forward infer_sharding_from_operands
""" """
del scale_factor, result_infos # Unused. del scale_factor, result_infos # Unused.
logits_spec = get_padded_spec(arg_infos[0]) logits_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) if logits_spec[-1] is not None:
warnings.warn(
f"Sharding the hidden dimension is not supported in {cls.name}! " \
f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
f"collective ops and hurt performance."
)
out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
return out_sharding return out_sharding
@staticmethod @classmethod
def forward_partition(impl, scale_factor, mesh, arg_infos, result_infos): def forward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
""" """
softmax_forward partitioning softmax_forward partitioning
""" """
del result_infos del result_infos
logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0]))) logits_spec = get_padded_spec(arg_infos[0])
out_spec = logits_spec if logits_spec[-1] is not None:
arg_shardings = (logits_spec,) warnings.warn(
out_shardings = out_spec f"Sharding the hidden dimension is not supported in {cls.name}! " \
f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
f"collective ops and hurt performance."
)
out_shardings = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
arg_shardings = (out_shardings,)
impl = partial(impl, scale_factor=scale_factor) impl = partial(impl, scale_factor=scale_factor)
return mesh, impl, out_shardings, arg_shardings return mesh, impl, out_shardings, arg_shardings
...@@ -1154,7 +1166,7 @@ class SoftmaxPrimitive(BasePrimitive): ...@@ -1154,7 +1166,7 @@ class SoftmaxPrimitive(BasePrimitive):
assert dz_aval.shape == softmax_out_aval.shape assert dz_aval.shape == softmax_out_aval.shape
dx_aval = core.raise_to_shaped(softmax_out_aval) dx_aval = core.raise_to_shaped(dz_aval)
return dx_aval return dx_aval
@staticmethod @staticmethod
...@@ -1177,7 +1189,7 @@ class SoftmaxPrimitive(BasePrimitive): ...@@ -1177,7 +1189,7 @@ class SoftmaxPrimitive(BasePrimitive):
softmax_out_type = ir.RankedTensorType(softmax_out.type) softmax_out_type = ir.RankedTensorType(softmax_out.type)
softmax_out_shape = softmax_out_type.shape softmax_out_shape = softmax_out_type.shape
out_types = [ir.RankedTensorType.get(softmax_out_shape, softmax_out_type.element_type)] out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)]
operands = [dz, softmax_out] operands = [dz, softmax_out]
operand_shapes = [dz_shape, softmax_out_shape] operand_shapes = [dz_shape, softmax_out_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
...@@ -1211,27 +1223,44 @@ class SoftmaxPrimitive(BasePrimitive): ...@@ -1211,27 +1223,44 @@ class SoftmaxPrimitive(BasePrimitive):
out_bdims = softmax_out_bdim out_bdims = softmax_out_bdim
return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims
@staticmethod @classmethod
def backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): def backward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
""" """
softmax_backward infer_sharding_from_operands softmax_backward infer_sharding_from_operands
""" """
del scale_factor, result_infos # Unused. del scale_factor, result_infos # Unused.
softmax_out_spec = get_padded_spec(arg_infos[1]) dz_spec = get_padded_spec(arg_infos[0])
dx_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec)) if dz_spec[-1] is not None:
warnings.warn(
f"Sharding the hidden dimension is not supported in {cls.name}! " \
f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
f"collective ops and hurt performance."
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
return dx_sharding return dx_sharding
@staticmethod @classmethod
def backward_partition(impl, scale_factor, mesh, arg_infos, result_infos): def backward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
""" """
softmax_backward partition softmax_backward partition
""" """
del result_infos del result_infos
dz_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
softmax_out_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) dz_spec = get_padded_spec(arg_infos[0])
dx_spec = softmax_out_spec softmax_out_spec = get_padded_spec(arg_infos[1])
arg_shardings = (dz_spec, softmax_out_spec) if dz_spec[-1] is not None or softmax_out_spec[-1] is not None:
out_shardings = dx_spec warnings.warn(
f"Sharding the hidden dimension is not supported in {cls.name}! " \
f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
f"collective ops and hurt performance."
)
dz_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
softmax_out_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec[:-1], None))
dx_sharding = dz_sharding
arg_shardings = (dz_sharding, softmax_out_sharding)
out_shardings = dx_sharding
impl = partial(impl, scale_factor=scale_factor) impl = partial(impl, scale_factor=scale_factor)
return mesh, impl, out_shardings, arg_shardings return mesh, impl, out_shardings, arg_shardings
...@@ -1296,13 +1325,15 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -1296,13 +1325,15 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
@staticmethod @staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
result_infos) scale_factor, mesh, arg_infos, result_infos
)
@staticmethod @staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos): def partition(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.forward_partition(ScaledSoftmaxFwdPrimitive.impl, scale_factor, return ScaledSoftmaxFwdPrimitive.forward_partition(
mesh, arg_infos, result_infos) ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
register_primitive(ScaledSoftmaxFwdPrimitive) register_primitive(ScaledSoftmaxFwdPrimitive)
...@@ -1370,13 +1401,15 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive): ...@@ -1370,13 +1401,15 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
@staticmethod @staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
result_infos) scale_factor, mesh, arg_infos, result_infos
)
@staticmethod @staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos): def partition(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_partition(ScaledSoftmaxBwdPrimitive.impl, scale_factor, return ScaledSoftmaxBwdPrimitive.backward_partition(
mesh, arg_infos, result_infos) ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
register_primitive(ScaledSoftmaxBwdPrimitive) register_primitive(ScaledSoftmaxBwdPrimitive)
...@@ -1505,20 +1538,15 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -1505,20 +1538,15 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
@staticmethod @staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
del scale_factor, result_infos # Unused. return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
logits_spec = get_padded_spec(arg_infos[0]) scale_factor, mesh, arg_infos,result_infos
out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) )
return out_sharding
@staticmethod @staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos): def partition(scale_factor, mesh, arg_infos, result_infos):
del result_infos return ScaledMaskedSoftmaxFwdPrimitive.backward_partition(
logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0]))) ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
mask_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) )
arg_shardings = (logits_spec, mask_spec)
out_shardings = logits_spec
impl = partial(ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor=scale_factor)
return mesh, impl, out_shardings, arg_shardings
register_primitive(ScaledMaskedSoftmaxFwdPrimitive) register_primitive(ScaledMaskedSoftmaxFwdPrimitive)
...@@ -1589,13 +1617,15 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...@@ -1589,13 +1617,15 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
@staticmethod @staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
result_infos) scale_factor, mesh, arg_infos, result_infos
)
@staticmethod @staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos): def partition(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_partition(ScaledMaskedSoftmaxBwdPrimitive.impl, return ScaledMaskedSoftmaxBwdPrimitive.backward_partition(
scale_factor, mesh, arg_infos, result_infos) ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
register_primitive(ScaledMaskedSoftmaxBwdPrimitive) register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
...@@ -1676,13 +1706,16 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -1676,13 +1706,16 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
@staticmethod @staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
result_infos) scale_factor, mesh, arg_infos, result_infos
)
@staticmethod @staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos): def partition(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.forward_partition(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition(
scale_factor, mesh, arg_infos, result_infos) ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh,
arg_infos, result_infos
)
register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive) register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)
...@@ -1753,13 +1786,16 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...@@ -1753,13 +1786,16 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
@staticmethod @staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
result_infos) scale_factor, mesh, arg_infos, result_infos
)
@staticmethod @staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos): def partition(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_partition(ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition(
scale_factor, mesh, arg_infos, result_infos) ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh,
arg_infos, result_infos
)
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive) register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment