Unverified Commit dfbf4dde authored by Reese Wang's avatar Reese Wang Committed by GitHub
Browse files

[JAX] Fix issues when mask/sequence_descriptor is None (#1477)



Fix issues when mask/sequence_descriptor is None
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 45e9d8b6
...@@ -556,6 +556,9 @@ class FusedAttnRunner: ...@@ -556,6 +556,9 @@ class FusedAttnRunner:
else: else:
match self.seq_desc_format: match self.seq_desc_format:
case SeqDescFormat.Mask: case SeqDescFormat.Mask:
if self.attn_mask_type == AttnMaskType.NO_MASK:
self.sequence_desciptor = None
else:
self.sequence_desciptor = make_mask( self.sequence_desciptor = make_mask(
self.segment_ids_q, self.segment_ids_q,
self.segment_ids_kv, self.segment_ids_kv,
......
...@@ -950,7 +950,7 @@ def fused_attn( ...@@ -950,7 +950,7 @@ def fused_attn(
AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK, AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
QKVLayout.T3HD, 0.125, 0, True, 3) QKVLayout.T3HD, 0.125, 0, True, 3)
""" """
if isinstance(sequence_descriptor, jnp.ndarray): if sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray):
warnings.warn( warnings.warn(
"Pass mask to fused_attn is deprecated, please use SequenceDescriptor instead. " "Pass mask to fused_attn is deprecated, please use SequenceDescriptor instead. "
+ "See help(transformer_engine.jax.attention.SequenceDescriptor) for details.", + "See help(transformer_engine.jax.attention.SequenceDescriptor) for details.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment