Commit b6020e3b authored by Paweł Gadziński, committed by GitHub

[JAX] Fix bug with pre scale bias (#2300)



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
parent 77a00635
@@ -197,6 +197,7 @@ class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-
             fused_scale_factor = scale_factor
             if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
                 attn_weights += bias
+                bias = None
         def apply_swa_mask(original_mask: Array) -> Array:
             """Apply the sliding window mask to a given mask"""
...
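
The diff itself does not say why `bias` is reset. One plausible reading is that a later step (for example the softmax) adds any non-`None` bias it receives, so a pre-scale bias would otherwise be applied twice. The sketch below illustrates that reading in plain Python on a single row of logits. `softmax_with_bias`, `attention_probs`, and `reference` are hypothetical stand-ins written for this illustration, not functions from the patched module.

```python
import math
from typing import List, Optional


def softmax_with_bias(
    logits: List[float], scale: float, bias: Optional[List[float]] = None
) -> List[float]:
    """Hypothetical downstream step: softmax(logits * scale + bias).

    Assumption: any bias that is not None gets added here.
    """
    shifted = [
        x * scale + (bias[i] if bias is not None else 0.0)
        for i, x in enumerate(logits)
    ]
    peak = max(shifted)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in shifted]
    total = sum(exps)
    return [e / total for e in exps]


def attention_probs(
    logits: List[float], bias: List[float], scale: float, clear_bias: bool
) -> List[float]:
    """Pre-scale-bias path: add the bias to the logits before scaling."""
    logits = [x + b for x, b in zip(logits, bias)]
    if clear_bias:
        # Mirrors the added line `bias = None`: the bias is already consumed.
        bias = None
    return softmax_with_bias(logits, scale, bias)


def reference(logits: List[float], bias: List[float], scale: float) -> List[float]:
    """Intended result: softmax((logits + bias) * scale), bias applied once."""
    return softmax_with_bias([x + b for x, b in zip(logits, bias)], scale)


row = [1.0, 2.0, 3.0]
row_bias = [0.5, -0.5, 0.0]
fixed = attention_probs(row, row_bias, scale=0.5, clear_bias=True)
buggy = attention_probs(row, row_bias, scale=0.5, clear_bias=False)
expected = reference(row, row_bias, scale=0.5)
```

Under these assumptions, `fixed` matches `expected`, while `buggy` does not because the bias enters the logits a second time inside the softmax step.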