Commit bb6f6145 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Slightly refactor recycling code to improve FP16 stability

parent b47138dc
......@@ -197,23 +197,24 @@ class AlphaFold(nn.Module):
feats["msa_feat"],
)
# Inject information from previous recycling iterations
if _recycle:
# Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m]
m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.input_embedder.c_m),
requires_grad=False,
)
# [*, N, N, C_z]
z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.input_embedder.c_z),
requires_grad=False,
)
# [*, N, 3]
x_prev = z.new_zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3),
requires_grad=False,
)
x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None)
......@@ -226,11 +227,19 @@ class AlphaFold(nn.Module):
x_prev,
)
# If the number of recycling iterations is 0, skip recycling
# altogether. We zero them this way instead of computing them
# conditionally to avoid leaving parameters unused, which has annoying
# implications for DDP training.
if(not _recycle):
m_1_prev_emb *= 0
z_prev_emb *= 0
# [*, S_c, N, C_m]
m[..., 0, :, :] = m[..., 0, :, :] + m_1_prev_emb
m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z]
z = z + z_prev_emb
z += z_prev_emb
# Possibly prevents memory fragmentation
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
......@@ -408,11 +417,12 @@ class AlphaFold(nn.Module):
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
# Sidestep AMP bug (PyTorch issue #65766)
if is_final_iter:
self._enable_activation_checkpointing()
# Sidestep AMP bug (PyTorch issue #65766)
if torch.is_autocast_enabled():
torch.clear_autocast_cache()
# Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment