Unverified Commit 04040957 authored by Alp Dener, committed by GitHub

[JAX] Bugfix for softmax primitives accepting invalid input sharding (#664)



* Softmax primitives now force XLA to unshard the hidden dimension and emit a warning. Unit tests updated to check numerics and the expected warning under bad sharding
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected cudnn-frontend version
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed mismatched output sharding
Signed-off-by: Alp Dener <adener@nvidia.com>

* combined softmax tests and fixed code style/linting issues
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
parent 8bba5eeb
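The observable effect of the fix can be sketched as a minimal repro. This is a sketch, not part of the commit: the import path for `softmax`/`SoftmaxType` is an assumption based on the test file below, and it presumes at least two visible devices. Before this change, the primitives accepted the invalid layout silently, which is what this PR fixes.

```python
import warnings

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumed import path, mirroring the test file below.
from transformer_engine.jax.softmax import SoftmaxType, softmax

devices = np.asarray(jax.devices()[:2]).reshape(2)
mesh = Mesh(devices, ('tp',))
x = jnp.zeros((4, 8, 128, 128), dtype=jnp.float16)

# "Bad" sharding: the hidden (softmax reduction) dim is sharded over 'tp'.
x = jax.device_put(x, NamedSharding(mesh, PartitionSpec(None, None, None, 'tp')))

with warnings.catch_warnings(record=True) as warns:
    warnings.simplefilter("always")
    out = jax.jit(lambda a: softmax(a, None, scale_factor=1.0,
                                    softmax_type=SoftmaxType.SCALED))(x)

# The primitive now unshards the hidden dim instead of letting each shard
# compute a softmax over only its local slice, and warns about the re-layout.
assert any("hidden dimension is not supported" in str(w.message) for w in warns)
```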
@@ -2,7 +2,9 @@
#
# See LICENSE for license information.
import warnings
import pytest
from functools import partial
import jax
import jax.numpy as jnp
@@ -26,7 +28,7 @@ class TestDistributedSoftmax:
all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(self, shape, mesh_resource, softmax_type, dtype):
def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding):
batch, _, sqelen, _ = shape
x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
@@ -35,25 +37,22 @@ class TestDistributedSoftmax:
else:
mask = make_self_mask(batch, sqelen)
x_pspec = PartitionSpec(mesh_resource.dp_resource, mesh_resource.tp_resource, None, None)
if not bad_sharding:
x_pspec = PartitionSpec(mesh_resource.dp_resource, mesh_resource.tp_resource,
None, None)
else:
x_pspec = PartitionSpec(mesh_resource.dp_resource, None,
None, mesh_resource.tp_resource)
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)
return (x, mask), (x_pspec, mask_pspec)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize(
'softmax_type',
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED])
@pytest.mark.parametrize('scale_factor', [1.0, 3.0])
@pytest.mark.parametrize('dtype', DTYPES)
def test_softmax(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
softmax_type, scale_factor, dtype):
def target_func(x, mask):
@staticmethod
def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED):
return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
def ref_func(x, mask):
@staticmethod
def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
bias = None
if mask is not None:
bias = jax.lax.select(mask > 0,
@@ -64,8 +63,24 @@ class TestDistributedSoftmax:
output = jax.nn.softmax(x * scale_factor)
return jnp.mean(output)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize(
'softmax_type',
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED])
@pytest.mark.parametrize('scale_factor', [1.0, 3.0])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('bad_sharding', [False, True])
def test_softmax(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
softmax_type, scale_factor, dtype, bad_sharding):
target_func = partial(self.target_func,
scale_factor=scale_factor,
softmax_type=softmax_type)
ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
(x, mask), (x_pspec, mask_pspec) = \
self.generate_inputs(data_shape, mesh_resource, softmax_type, dtype)
self.generate_inputs(data_shape, mesh_resource, softmax_type, dtype, bad_sharding)
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
@@ -73,6 +88,8 @@ class TestDistributedSoftmax:
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec))
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func,
ref_func, [x_, mask_],
collective_count_ref,
@@ -81,3 +98,16 @@ class TestDistributedSoftmax:
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, mask_pspec),
out_shardings=(None, (x_pspec,)))
except AssertionError as err:
# Softmax should still produce the correct numerical result with
# bad sharding. However, the collective count may not be the same
# when XLA is forced to unshard the hidden dimension. We can catch
# and ignore that specific error here.
if not bad_sharding or "Expected collective count" not in str(err):
raise err
finally:
for w in warns:
assert "Sharding the hidden dimension is not supported" in str(w), (
"Softmax primitive did not raise the correct warning for "
"unsupported sharding in the hidden dimension."
)
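For reference, the two layouts that the updated `generate_inputs` toggles between differ only in where the tensor-parallel axis lands. A small sketch, with `'dp'`/`'tp'` standing in for `mesh_resource.dp_resource`/`mesh_resource.tp_resource`:

```python
from jax.sharding import PartitionSpec

# Supported layout: tensor-parallel axis shards the heads dimension.
good_spec = PartitionSpec('dp', 'tp', None, None)

# "Bad" layout: tensor-parallel axis shards the hidden (softmax) dimension,
# which the primitives must now detect, unshard, and warn about.
bad_spec = PartitionSpec('dp', None, None, 'tp')
```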
@@ -1029,6 +1029,7 @@ class SoftmaxPrimitive(BasePrimitive):
Softmax Primitive
"""
max_k_seqlen_supported = 4096
name = "te_softmax_internal_placeholder"
@staticmethod
@abstractmethod
@@ -1118,26 +1119,37 @@ class SoftmaxPrimitive(BasePrimitive):
out_bdims = logits_bdim
return primitive.bind(logits, scale_factor=scale_factor), out_bdims
@staticmethod
def forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
@classmethod
def forward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
"""
softmax_forward infer_sharding_from_operands
"""
del scale_factor, result_infos # Unused.
logits_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec))
if logits_spec[-1] is not None:
warnings.warn(
f"Sharding the hidden dimension is not supported in {cls.name}! " \
f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
f"collective ops and hurt performance."
)
out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
return out_sharding
@staticmethod
def forward_partition(impl, scale_factor, mesh, arg_infos, result_infos):
@classmethod
def forward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
"""
softmax_forward partitioning
"""
del result_infos
logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
out_spec = logits_spec
arg_shardings = (logits_spec,)
out_shardings = out_spec
logits_spec = get_padded_spec(arg_infos[0])
if logits_spec[-1] is not None:
warnings.warn(
f"Sharding the hidden dimension is not supported in {cls.name}! " \
f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
f"collective ops and hurt performance."
)
out_shardings = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
arg_shardings = (out_shardings,)
impl = partial(impl, scale_factor=scale_factor)
return mesh, impl, out_shardings, arg_shardings
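The guard added above relies on one idiom: `get_padded_spec` returns the operand's `PartitionSpec` entries padded with `None` up to its rank, and rebuilding the spec with the last entry forced to `None` clears only the hidden-dim sharding while preserving the batch/head layout. A tiny illustration with a hypothetical padded spec:

```python
from jax.sharding import PartitionSpec

padded_spec = ('dp', None, None, 'tp')    # hidden dim sharded over 'tp'
fixed = PartitionSpec(*padded_spec[:-1], None)
# -> PartitionSpec('dp', None, None, None): only the last dim is unsharded.
```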
@@ -1154,7 +1166,7 @@ class SoftmaxPrimitive(BasePrimitive):
assert dz_aval.shape == softmax_out_aval.shape
dx_aval = core.raise_to_shaped(softmax_out_aval)
dx_aval = core.raise_to_shaped(dz_aval)
return dx_aval
@staticmethod
@@ -1177,7 +1189,7 @@ class SoftmaxPrimitive(BasePrimitive):
softmax_out_type = ir.RankedTensorType(softmax_out.type)
softmax_out_shape = softmax_out_type.shape
out_types = [ir.RankedTensorType.get(softmax_out_shape, softmax_out_type.element_type)]
out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)]
operands = [dz, softmax_out]
operand_shapes = [dz_shape, softmax_out_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
@@ -1211,27 +1223,44 @@ class SoftmaxPrimitive(BasePrimitive):
out_bdims = softmax_out_bdim
return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims
@staticmethod
def backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
@classmethod
def backward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
"""
softmax_backward infer_sharding_from_operands
"""
del scale_factor, result_infos # Unused.
softmax_out_spec = get_padded_spec(arg_infos[1])
dx_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec))
dz_spec = get_padded_spec(arg_infos[0])
if dz_spec[-1] is not None:
warnings.warn(
f"Sharding the hidden dimension is not supported in {cls.name}! " \
f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
f"collective ops and hurt performance."
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
return dx_sharding
@staticmethod
def backward_partition(impl, scale_factor, mesh, arg_infos, result_infos):
@classmethod
def backward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
"""
softmax_backward partition
"""
del result_infos
dz_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
softmax_out_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
dx_spec = softmax_out_spec
arg_shardings = (dz_spec, softmax_out_spec)
out_shardings = dx_spec
dz_spec = get_padded_spec(arg_infos[0])
softmax_out_spec = get_padded_spec(arg_infos[1])
if dz_spec[-1] is not None or softmax_out_spec[-1] is not None:
warnings.warn(
f"Sharding the hidden dimension is not supported in {cls.name}! " \
f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
f"collective ops and hurt performance."
)
dz_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
softmax_out_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec[:-1], None))
dx_sharding = dz_sharding
arg_shardings = (dz_sharding, softmax_out_sharding)
out_shardings = dx_sharding
impl = partial(impl, scale_factor=scale_factor)
return mesh, impl, out_shardings, arg_shardings
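These `infer_sharding_from_operands`/`partition` pairs are the callbacks that `register_primitive` hands to JAX's `jax.experimental.custom_partitioning` machinery, where a `partition` callback returns `(mesh, lowering_fn, out_shardings, arg_shardings)`. Below is a self-contained toy version of the same unshard-and-warn pattern, written against the public JAX API rather than the Transformer Engine helpers (all names here are illustrative):

```python
import warnings

import jax
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import NamedSharding, PartitionSpec


@custom_partitioning
def toy_softmax(x):
    return jax.nn.softmax(x, axis=-1)


def _padded_spec(arg_info):
    # Mirrors get_padded_spec: pad the PartitionSpec with None up to rank.
    spec = arg_info.sharding.spec
    return tuple(spec) + (None,) * (len(arg_info.shape) - len(spec))


def _unshard_hidden(mesh, spec):
    # Same unshard-and-warn pattern as the TE helpers above.
    if spec[-1] is not None:
        warnings.warn("Sharding the hidden dimension is not supported! "
                      "Forcing XLA to not shard the hidden dim.")
    return NamedSharding(mesh, PartitionSpec(*spec[:-1], None))


def _infer(mesh, arg_infos, result_infos):
    del result_infos  # Unused.
    return _unshard_hidden(mesh, _padded_spec(arg_infos[0]))


def _partition(mesh, arg_infos, result_infos):
    del result_infos  # Unused.
    sharding = _unshard_hidden(mesh, _padded_spec(arg_infos[0]))

    def lower_fn(x):
        return jax.nn.softmax(x, axis=-1)

    # Same return convention as forward_partition/backward_partition:
    # (mesh, lowering fn, out_shardings, arg_shardings).
    return mesh, lower_fn, sharding, (sharding,)


toy_softmax.def_partition(infer_sharding_from_operands=_infer,
                          partition=_partition)
```

Under `jax.jit`, an operand sharded on its last dimension then triggers the warning at compile time, and XLA inserts whatever collectives are needed to materialize the unsharded hidden dimension.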
@@ -1296,13 +1325,15 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
result_infos)
return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)
@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.forward_partition(ScaledSoftmaxFwdPrimitive.impl, scale_factor,
mesh, arg_infos, result_infos)
return ScaledSoftmaxFwdPrimitive.forward_partition(
ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
register_primitive(ScaledSoftmaxFwdPrimitive)
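The `@staticmethod` to `@classmethod` change in the shared helpers, together with subclasses calling the helpers through their own class as above, is what lets the warning name the concrete primitive: `cls.name` resolves to the subclass's registered name instead of the base class's `te_softmax_internal_placeholder`. A minimal sketch with hypothetical classes:

```python
class SoftmaxPrimitiveSketch:
    name = "te_softmax_internal_placeholder"

    @classmethod
    def warn_source(cls):
        # `cls` is whichever subclass the call went through, so the
        # warning can name the concrete primitive.
        return f"Sharding the hidden dimension is not supported in {cls.name}!"


class ScaledSoftmaxFwdSketch(SoftmaxPrimitiveSketch):
    name = "te_scaled_softmax_forward"


assert "te_scaled_softmax_forward" in ScaledSoftmaxFwdSketch.warn_source()
```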
@@ -1370,13 +1401,15 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
result_infos)
return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)
@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_partition(ScaledSoftmaxBwdPrimitive.impl, scale_factor,
mesh, arg_infos, result_infos)
return ScaledSoftmaxBwdPrimitive.backward_partition(
ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
register_primitive(ScaledSoftmaxBwdPrimitive)
@@ -1505,20 +1538,15 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
del scale_factor, result_infos # Unused.
logits_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec))
return out_sharding
return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)
@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
del result_infos
logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
mask_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = (logits_spec, mask_spec)
out_shardings = logits_spec
impl = partial(ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor=scale_factor)
return mesh, impl, out_shardings, arg_shardings
# NOTE: reuses backward_partition because it already shards two operands
# (here: logits and mask); forward_partition handles only one.
return ScaledMaskedSoftmaxFwdPrimitive.backward_partition(
ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
register_primitive(ScaledMaskedSoftmaxFwdPrimitive)
@@ -1589,13 +1617,15 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
result_infos)
return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)
@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_partition(ScaledMaskedSoftmaxBwdPrimitive.impl,
scale_factor, mesh, arg_infos, result_infos)
return ScaledMaskedSoftmaxBwdPrimitive.backward_partition(
ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
@@ -1676,13 +1706,16 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
result_infos)
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)
@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.forward_partition(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl,
scale_factor, mesh, arg_infos, result_infos)
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition(
ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh,
arg_infos, result_infos
)
register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)
@@ -1753,13 +1786,16 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
result_infos)
return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)
@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_partition(ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl,
scale_factor, mesh, arg_infos, result_infos)
return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition(
ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh,
arg_infos, result_infos
)
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
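Finally, to see the extra collectives the warnings mention (and why the test relaxes its collective-count assertion under `bad_sharding`), the compiled HLO can be inspected with standard JAX APIs. Continuing the repro sketch from the top of this page (a sketch; the exact collective XLA picks may vary):

```python
# Count collectives in the compiled HLO for the badly-sharded input `x`.
lowered = jax.jit(lambda a: softmax(a, None, scale_factor=1.0,
                                    softmax_type=SoftmaxType.SCALED)).lower(x)
hlo_text = lowered.compile().as_text()

# Typically nonzero when the hidden dim had to be gathered; zero (for this
# op) with the supported layout.
print(hlo_text.count("all-gather"))
```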