Commit 16bab4d3 authored by Tim Green, committed by Copybara-Service

Remove unused bias arg in GlobalAttention.

PiperOrigin-RevId: 430053169
Change-Id: Ia7461bc23d6e0c94f5bef8c1ded1c8306193e67e
parent 89a6875f
@@ -632,7 +632,7 @@ class GlobalAttention(hk.Module):
     self.global_config = global_config
     self.output_dim = output_dim
 
-  def __call__(self, q_data, m_data, q_mask, bias):
+  def __call__(self, q_data, m_data, q_mask):
     """Builds GlobalAttention module.
 
     Arguments:
@@ -643,7 +643,6 @@ class GlobalAttention(hk.Module):
       q_mask: A binary mask for q_data with zeros in the padded sequence
         elements and ones otherwise. Size [batch_size, N_queries, q_channels]
         (or broadcastable to this shape).
-      bias: A bias for the attention.
 
     Returns:
       A float32 tensor of size [batch_size, N_queries, output_dim].
@@ -880,7 +879,7 @@ class MSAColumnGlobalAttention(hk.Module):
     msa_act = mapping.inference_subbatch(
         attn_mod,
         self.global_config.subbatch_size,
-        batched_args=[msa_act, msa_act, msa_mask, bias],
+        batched_args=[msa_act, msa_act, msa_mask],
         nonbatched_args=[],
         low_memory=not is_training)
...
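For context, a plausible reading of why the bias argument was never used: in this global-attention variant the queries are averaged over the sequence dimension before the attention logits are formed, so the logits are per-key rather than per-query/per-key, and there is no pairwise logit matrix for an externally supplied bias to attach to; padding is already described by q_mask. The sketch below illustrates that mean-query attention pattern in plain JAX. It is not the AlphaFold GlobalAttention module; the function name, weight arguments, and shapes are illustrative assumptions.

# Simplified mean-query ("global") attention sketch. Not the AlphaFold
# implementation; names, weights, and shapes are assumptions for illustration.
import jax
import jax.numpy as jnp


def toy_global_attention(q_data, m_data, q_mask, w_q, w_k, w_v):
  # q_data, m_data: [batch, N, channels]; q_mask: [batch, N, channels].
  # w_q, w_k, w_v: [channels, hidden].
  # Average the (masked) queries over the sequence dimension: the resulting
  # logits are [batch, N_keys], so there is no [N_queries, N_keys] matrix
  # that a pairwise bias could be added to.
  q_avg = jnp.sum(q_mask * q_data, axis=1) / (jnp.sum(q_mask, axis=1) + 1e-10)
  q = jnp.einsum('bc,ch->bh', q_avg, w_q)      # one query vector per example
  k = jnp.einsum('bkc,ch->bkh', m_data, w_k)   # one key per sequence element
  v = jnp.einsum('bkc,ch->bkh', m_data, w_v)
  logits = jnp.einsum('bh,bkh->bk', q, k) / jnp.sqrt(q.shape[-1])
  weights = jax.nn.softmax(logits, axis=-1)    # attention over keys only
  return jnp.einsum('bk,bkh->bh', weights, v)  # [batch, hidden]


if __name__ == '__main__':
  keys = jax.random.split(jax.random.PRNGKey(0), 5)
  b, n, c, h = 2, 5, 8, 4
  q_data = jax.random.normal(keys[0], (b, n, c))
  m_data = jax.random.normal(keys[1], (b, n, c))
  q_mask = jnp.ones((b, n, c))
  w_q, w_k, w_v = (jax.random.normal(k, (c, h)) for k in keys[2:5])
  print(toy_global_attention(q_data, m_data, q_mask, w_q, w_k, w_v).shape)  # (2, 4)

After this change, GlobalAttention is called as attn(q_data, m_data, q_mask), matching the three batched_args now passed through mapping.inference_subbatch in MSAColumnGlobalAttention.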