Unverified Commit 47ab4a74 authored by Kshitij Lakhani's avatar Kshitij Lakhani Committed by GitHub
Browse files

[JAX] Add Transformer Layer tests for pre_scale_bias and post_scale_bias (#2104)



Add Transformer Layer tests for pre_scale_bias and post_scale_bias
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
parent 2e23ad71
...@@ -263,6 +263,16 @@ ATTRS = [ ...@@ -263,6 +263,16 @@ ATTRS = [
_KEY_OF_RELATIVE_EMBEDDING: False, _KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_WINDOW_SIZE: (2, 2), _KEY_OF_WINDOW_SIZE: (2, 2),
}, },
# attrs29
{
_KEY_OF_RELATIVE_EMBEDDING: True,
_KEY_OF_SELF_ATTN_BIAS_TYPE: "pre_scale_bias",
},
# attrs30
{
_KEY_OF_RELATIVE_EMBEDDING: True,
_KEY_OF_SELF_ATTN_BIAS_TYPE: "post_scale_bias",
},
] ]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment