Unverified Commit e1e83598 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[JAX] Debug distributed attention tests (#1038)



* Remove extra args to fused attention func
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add missing arg to fused attention func
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 08b49976
......@@ -124,12 +124,9 @@ class TestDistributedSelfAttn:
bias,
mask,
None,
None,
None,
None,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=QKVLayout.BS3HD,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
......@@ -260,12 +257,9 @@ class TestDistributedCrossAttn:
None,
mask,
None,
None,
None,
None,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=QKVLayout.BSHD_BS2HD,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment