Commit bf3eed6e authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Remove dependency on training-time feature

parent 267b1bfd
...@@ -169,7 +169,7 @@ class AlphaFold(nn.Module): ...@@ -169,7 +169,7 @@ class AlphaFold(nn.Module):
return ret return ret
def iteration(self, feats, m_1_prev, z_prev, x_prev): def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
# Establish constants # Establish constants
chunk_size = ( chunk_size = (
self.globals.train_chunk_size self.globals.train_chunk_size
...@@ -202,7 +202,7 @@ class AlphaFold(nn.Module): ...@@ -202,7 +202,7 @@ class AlphaFold(nn.Module):
) )
# Inject information from previous recycling iterations # Inject information from previous recycling iterations
if feats["no_recycling_iters"] > 0: if _recycle is True:
# Initialize the recycling embeddings, if needs be # Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]: if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m] # [*, N, C_m]
...@@ -420,6 +420,7 @@ class AlphaFold(nn.Module): ...@@ -420,6 +420,7 @@ class AlphaFold(nn.Module):
m_1_prev, m_1_prev,
z_prev, z_prev,
x_prev, x_prev,
_recycle=(num_iters > 1)
) )
# Run auxiliary heads # Run auxiliary heads
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment