Unverified commit 728e335f authored by George Karpenkov, committed by GitHub
Browse files

Fix types for forward attention for JAX. (#704)



Bias and seed can both be None; type checking fails otherwise.
Signed-off-by: George Karpenkov <george@metaworld.me>
parent d8f678dc
...@@ -2075,9 +2075,10 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2075,9 +2075,10 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
register_primitive(SelfFusedAttnFwdPrimitive) register_primitive(SelfFusedAttnFwdPrimitive)
def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray, seed: jnp.ndarray, def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray | None, seqlen: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, seed: jnp.ndarray | None, attn_bias_type: NVTE_Bias_Type,
scaling_factor: float, dropout_probability: float, is_training: bool): attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
""" """
Wrapper for TE self fused attention fwd Wrapper for TE self fused attention fwd
Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2 Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2
......
...@@ -66,9 +66,10 @@ def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, ...@@ -66,9 +66,10 @@ def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type,
max_seqlen_q, max_seqlen_kv, head_dim).is_fused_attn_kernel_available() max_seqlen_q, max_seqlen_kv, head_dim).is_fused_attn_kernel_available()
def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray, def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
scaling_factor: float, dropout_probability: float, is_training: bool): attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
""" """
Self fused attention wrapper Self fused attention wrapper
""" """
...@@ -86,19 +87,22 @@ def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed ...@@ -86,19 +87,22 @@ def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8)) @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray, def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
scaling_factor: float, dropout_probability: float, is_training: bool): attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
output, _ = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type, attn_mask_type, output, _ = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training) scaling_factor, dropout_probability, is_training)
return output return output
def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None,
seed: jnp.ndarray, attn_bias_type: AttnBiasType, mask: jnp.ndarray, seed: jnp.ndarray | None,
attn_mask_type: AttnMaskType, scaling_factor: float, attn_bias_type: AttnBiasType,
dropout_probability: float, is_training: bool): attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float,
is_training: bool):
if mask is None: if mask is None:
batch, seqlen, *_ = qkv.shape batch, seqlen, *_ = qkv.shape
actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32) actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment