Unverified commit bee4649c authored by Michael Goldfarb, committed by GitHub

[JAX] Fix softmax aux shapes for packed/THD format (#1575)



* Fix softmax shape for THD format.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent 4f33ece4
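For context, a minimal sketch of the shape rule this commit introduces. The helper below is a hypothetical stand-in, not TE's API; the real logic (shown in the diff that follows) queries get_cudnn_version() and config.qkv_layout.is_thd(). On cuDNN 9.6+ the packed/THD layout stores softmax_aux as (b, s, h, 1), while all other cases keep (b, h, s, 1).

def softmax_aux_shape(batch_shape, attn_heads, q_max_seqlen, is_thd, cudnn_version):
    # Hypothetical stand-in for the shape selection in FusedAttnFwdPrimitive.
    if cudnn_version >= (9, 6, 0) and is_thd:
        # Packed/THD on cuDNN 9.6+: sequence axis before head axis.
        return (*batch_shape, q_max_seqlen, attn_heads, 1)
    # All other layouts: head axis before sequence axis.
    return (*batch_shape, attn_heads, q_max_seqlen, 1)

assert softmax_aux_shape((2,), 16, 128, True, (9, 6, 0)) == (2, 128, 16, 1)
assert softmax_aux_shape((2,), 16, 128, False, (9, 6, 0)) == (2, 16, 128, 1)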
@@ -295,6 +295,9 @@ class FusedAttnFwdPrimitive(BasePrimitive):
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
# cuDNN 9.6 reduces the required softmax shape
if get_cudnn_version() >= (9, 6, 0):
if config.qkv_layout.is_thd():
softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1)
else:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
else:
softmax_shape = (
@@ -607,28 +610,49 @@ class FusedAttnFwdPrimitive(BasePrimitive):
def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0])
# On cuDNN 9.6+ with the THD (packed) layout, softmax_aux has shape (b, s, h, 1);
# otherwise it has shape (b, h, s, 1) or (b, h, s, max_segments).
is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()
if config.qkv_layout.is_qkvpacked():
# q_spec = (...batch, q_seqlen, 3, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
if not is_packed_softmax:
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)
)
else:
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-4], q_spec[-4], q_spec[-2], None)
)
elif config.qkv_layout.is_kvpacked():
# q_spec = (...batch, q_seqlen, head, hidden)
# k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
if not is_packed_softmax:
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
)
else:
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None)
)
elif config.qkv_layout.is_separate():
# q_spec = (...batch, q_seqlen, head, hidden)
# k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
if not is_packed_softmax:
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
)
else:
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None)
)
else:
raise ValueError(f"Unsupported {config.qkv_layout=}")
rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
return (out_sharding, softmax_aux_sharding, rng_state_sharding)
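A toy illustration of the PartitionSpec permutation above for the separate-QKV case. The mesh axis names "dp" and "tp" are hypothetical; the point is that the softmax_aux spec reorders the sequence and head entries of q_spec depending on whether the packed (b, s, h, 1) layout is in effect.

from jax.sharding import PartitionSpec

q_spec = ("dp", None, "tp", None)  # (...batch, q_seqlen, head, hidden)

# Unpacked aux is (b, h, s, 1): head axis before sequence axis.
aux_bhs = PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
# Packed THD aux on cuDNN 9.6+ is (b, s, h, 1): sequence axis before head axis.
aux_bsh = PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None)

assert aux_bhs == PartitionSpec("dp", "tp", None, None)
assert aux_bsh == PartitionSpec("dp", None, "tp", None)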
@@ -2236,7 +2260,6 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
subblock_config,
)
# TODO(rewang): THD softmax_aux layout is actually [B, S, H]
softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1))
def skip_correction(_output, _softmax_aux, output_per_step, softmax_aux_per_step):
@@ -2272,8 +2295,6 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
carry = scan_kv_block(i, carry)
(_, _, _, output, softmax_aux) = carry
softmax_aux = softmax_aux.reshape((batch, head, q_max_seqlen, 1))
return output.astype(q.dtype), softmax_aux, rng_state
return mesh, fwd_impl, out_shardings, arg_shardings
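For readers unfamiliar with the per-step correction that scan_kv_block applies, here is a minimal log-sum-exp merge in the spirit of ring/flash attention. This is an assumption about the math, not the fused kernel's implementation, and the names are hypothetical; it only shows why softmax_aux carries a trailing unit axis that broadcasts against the output.

import jax.numpy as jnp

def merge_partial_softmax(output, lse, output_per_step, lse_per_step):
    # Hypothetical merge of two partial softmax results over the same queries.
    # lse holds running log-sum-exp stats with a trailing unit axis, e.g.
    # (b, s, h, 1) for THD, so it broadcasts against (b, s, h, d) outputs.
    new_lse = jnp.logaddexp(lse, lse_per_step)
    merged = (output * jnp.exp(lse - new_lse)
              + output_per_step * jnp.exp(lse_per_step - new_lse))
    return merged, new_lse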