Unverified commit dfbf4dde, authored by Reese Wang, committed by GitHub
Browse files

[JAX] Fix issues when mask/sequence_descriptor is None (#1477)



Fix issues when mask/sequence_descriptor is None
Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 45e9d8b6
......@@ -556,13 +556,16 @@ class FusedAttnRunner:
else:
match self.seq_desc_format:
case SeqDescFormat.Mask:
self.sequence_desciptor = make_mask(
self.segment_ids_q,
self.segment_ids_kv,
self.segment_pos_q,
self.segment_pos_kv,
self.attn_mask_type,
)
if self.attn_mask_type == AttnMaskType.NO_MASK:
self.sequence_desciptor = None
else:
self.sequence_desciptor = make_mask(
self.segment_ids_q,
self.segment_ids_kv,
self.segment_pos_q,
self.segment_pos_kv,
self.attn_mask_type,
)
case SeqDescFormat.Seqlens:
self.sequence_desciptor = SequenceDescriptor.from_seqlens(
(
......
......@@ -950,7 +950,7 @@ def fused_attn(
AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
QKVLayout.T3HD, 0.125, 0, True, 3)
"""
if isinstance(sequence_descriptor, jnp.ndarray):
if sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray):
warnings.warn(
"Pass mask to fused_attn is deprecated, please use SequenceDescriptor instead. "
+ "See help(transformer_engine.jax.attention.SequenceDescriptor) for details.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment