Unverified Commit bee4649c authored by Michael Goldfarb Committed by GitHub
Browse files

[JAX] Fix softmax aux shapes for packed/THD format (#1575)



* Fix softmax shape for THD format.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent 4f33ece4
...@@ -295,7 +295,10 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -295,7 +295,10 @@ class FusedAttnFwdPrimitive(BasePrimitive):
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
# cuDNN 9.6 reduces the required softmax shape # cuDNN 9.6 reduces the required softmax shape
if get_cudnn_version() >= (9, 6, 0): if get_cudnn_version() >= (9, 6, 0):
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) if config.qkv_layout.is_thd():
softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1)
else:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
else: else:
softmax_shape = ( softmax_shape = (
*batch_shape, *batch_shape,
...@@ -607,28 +610,49 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -607,28 +610,49 @@ class FusedAttnFwdPrimitive(BasePrimitive):
def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
del result_infos del result_infos
q_spec = get_padded_spec(arg_infos[0]) q_spec = get_padded_spec(arg_infos[0])
# when supported softmax_aux shape is (b, s, h, 1) for thd on cudnn 9.6+
# otherwise softmax_aux shape is (b, h, s, 1) or (b, h, s, max_segments)
is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()
if config.qkv_layout.is_qkvpacked(): if config.qkv_layout.is_qkvpacked():
# q_spec = (...batch, q_seqlen, 3, head, hidden) # q_spec = (...batch, q_seqlen, 3, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
softmax_aux_sharding = NamedSharding( if not is_packed_softmax:
mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None) softmax_aux_sharding = NamedSharding(
) mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)
)
else:
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-4], q_spec[-4], q_spec[-2], None)
)
elif config.qkv_layout.is_kvpacked(): elif config.qkv_layout.is_kvpacked():
# q_spec = (...batch, q_seqlen, head, hidden) # q_spec = (...batch, q_seqlen, head, hidden)
# k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding( if not is_packed_softmax:
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) softmax_aux_sharding = NamedSharding(
) mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
)
else:
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None)
)
elif config.qkv_layout.is_separate(): elif config.qkv_layout.is_separate():
# q_spec = (...batch, q_seqlen, head, hidden) # q_spec = (...batch, q_seqlen, head, hidden)
# k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding( if not is_packed_softmax:
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) softmax_aux_sharding = NamedSharding(
) mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
)
else:
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None)
)
else: else:
raise ValueError(f"Unsupported {config.qkv_layout=}") raise ValueError(f"Unsupported {config.qkv_layout=}")
rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
return (out_sharding, softmax_aux_sharding, rng_state_sharding) return (out_sharding, softmax_aux_sharding, rng_state_sharding)
...@@ -2236,7 +2260,6 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2236,7 +2260,6 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
subblock_config, subblock_config,
) )
# TODO(rewang): THD softmax_aux layout is actually [B, S, H]
softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1)) softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1))
def skip_correction(_output, _softmax_aux, output_per_step, softmax_aux_per_step): def skip_correction(_output, _softmax_aux, output_per_step, softmax_aux_per_step):
...@@ -2272,8 +2295,6 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2272,8 +2295,6 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
carry = scan_kv_block(i, carry) carry = scan_kv_block(i, carry)
(_, _, _, output, softmax_aux) = carry (_, _, _, output, softmax_aux) = carry
softmax_aux = softmax_aux.reshape((batch, head, q_max_seqlen, 1))
return output.astype(q.dtype), softmax_aux, rng_state return output.astype(q.dtype), softmax_aux, rng_state
return mesh, fwd_impl, out_shardings, arg_shardings return mesh, fwd_impl, out_shardings, arg_shardings
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment