Commit bb6f6145 authored by Gustaf Ahdritz

Slightly refactor recycling code to improve FP16 stability

parent b47138dc
@@ -197,43 +197,52 @@ class AlphaFold(nn.Module):
             feats["msa_feat"],
         )

-        # Inject information from previous recycling iterations
-        if _recycle:
-            # Initialize the recycling embeddings, if needs be
-            if None in [m_1_prev, z_prev, x_prev]:
-                # [*, N, C_m]
-                m_1_prev = m.new_zeros(
-                    (*batch_dims, n, self.config.input_embedder.c_m),
-                )
-
-                # [*, N, N, C_z]
-                z_prev = z.new_zeros(
-                    (*batch_dims, n, n, self.config.input_embedder.c_z),
-                )
-
-                # [*, N, 3]
-                x_prev = z.new_zeros(
-                    (*batch_dims, n, residue_constants.atom_type_num, 3),
-                )
-
-            x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None)
-
-            # m_1_prev_emb: [*, N, C_m]
-            # z_prev_emb: [*, N, N, C_z]
-            m_1_prev_emb, z_prev_emb = self.recycling_embedder(
-                m_1_prev,
-                z_prev,
-                x_prev,
-            )
-
-            # [*, S_c, N, C_m]
-            m[..., 0, :, :] = m[..., 0, :, :] + m_1_prev_emb
-
-            # [*, N, N, C_z]
-            z = z + z_prev_emb
-
-            # Possibly prevents memory fragmentation
-            del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
+        # Initialize the recycling embeddings, if needs be
+        if None in [m_1_prev, z_prev, x_prev]:
+            # [*, N, C_m]
+            m_1_prev = m.new_zeros(
+                (*batch_dims, n, self.config.input_embedder.c_m),
+                requires_grad=False,
+            )
+
+            # [*, N, N, C_z]
+            z_prev = z.new_zeros(
+                (*batch_dims, n, n, self.config.input_embedder.c_z),
+                requires_grad=False,
+            )
+
+            # [*, N, 3]
+            x_prev = z.new_zeros(
+                (*batch_dims, n, residue_constants.atom_type_num, 3),
+                requires_grad=False,
+            )
+
+        x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None)
+
+        # m_1_prev_emb: [*, N, C_m]
+        # z_prev_emb: [*, N, N, C_z]
+        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
+            m_1_prev,
+            z_prev,
+            x_prev,
+        )
+
+        # If the number of recycling iterations is 0, skip recycling
+        # altogether. We zero them this way instead of computing them
+        # conditionally to avoid leaving parameters unused, which has annoying
+        # implications for DDP training.
+        if(not _recycle):
+            m_1_prev_emb *= 0
+            z_prev_emb *= 0
+
+        # [*, S_c, N, C_m]
+        m[..., 0, :, :] += m_1_prev_emb
+
+        # [*, N, N, C_z]
+        z += z_prev_emb
+
+        # Possibly prevents memory fragmentation
+        del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb

         # Embed the templates + merge with MSA/pair embeddings
         if self.config.template.enabled:
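The DDP note in the comment block above relies on a general PyTorch behavior: DistributedDataParallel expects every registered parameter to take part in producing the loss on each step (unless find_unused_parameters=True), so running recycling_embedder unconditionally and zeroing its outputs keeps its weights in the autograd graph even when recycling is disabled. The snippet below is a minimal sketch of that "multiply by zero instead of skipping" pattern with a hypothetical ToyRecycler module; it illustrates the trick and is not OpenFold code.

import torch
import torch.nn as nn

class ToyRecycler(nn.Module):
    """Hypothetical module with an optional branch that must stay DDP-friendly."""
    def __init__(self, dim: int = 8):
        super().__init__()
        self.trunk = nn.Linear(dim, dim)
        self.recycling_branch = nn.Linear(dim, dim)  # stand-in for the recycling embedder

    def forward(self, x: torch.Tensor, recycle: bool) -> torch.Tensor:
        out = self.trunk(x)
        # Always run the branch; zero its output instead of skipping it, so
        # recycling_branch's weights are never "unused parameters" under DDP.
        extra = self.recycling_branch(x)
        if not recycle:
            extra = extra * 0
        return out + extra

if __name__ == "__main__":
    model = ToyRecycler()
    x = torch.randn(2, 8)
    loss = model(x, recycle=False).sum()
    loss.backward()
    # Gradients exist (all zeros) even though the branch was "disabled".
    print(model.recycling_branch.weight.grad.abs().sum())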
@@ -408,11 +417,12 @@ class AlphaFold(nn.Module):
             # Enable grad iff we're training and it's the final recycling layer
             is_final_iter = cycle_no == (num_iters - 1)
             with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
-                # Sidestep AMP bug (PyTorch issue #65766)
                 if is_final_iter:
                     self._enable_activation_checkpointing()
+                    # Sidestep AMP bug (PyTorch issue #65766)
                     if torch.is_autocast_enabled():
                         torch.clear_autocast_cache()
+
                 # Run the next iteration of the model
                 outputs, m_1_prev, z_prev, x_prev = self.iteration(
                     feats,
...
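For context on the second hunk, here is a minimal sketch of the recycling-loop pattern it adjusts; ToyBlock and recycle_forward are hypothetical stand-ins, not the OpenFold implementation. Gradients are enabled only for the final recycling iteration, and the autocast cache is cleared there, which is the workaround for the AMP bug the comment cites (PyTorch issue #65766).

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Hypothetical stand-in for one recycling iteration: refines prev using x."""
    def __init__(self, dim: int = 8):
        super().__init__()
        self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x: torch.Tensor, prev: torch.Tensor) -> torch.Tensor:
        return self.linear(torch.cat([x, prev], dim=-1))

def recycle_forward(block: ToyBlock, x: torch.Tensor, num_iters: int = 3) -> torch.Tensor:
    prev = torch.zeros_like(x)
    is_grad_enabled = torch.is_grad_enabled()
    for cycle_no in range(num_iters):
        # Enable grad iff this is the final recycling iteration (and the caller
        # had grad enabled), mirroring the loop in the hunk above.
        is_final_iter = cycle_no == (num_iters - 1)
        with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
            if is_final_iter and torch.is_autocast_enabled():
                # Sidestep the AMP bug referenced above (PyTorch issue #65766)
                # by clearing the autocast cache before the grad-enabled pass.
                torch.clear_autocast_cache()
            prev = block(x, prev)
    return prev

block = ToyBlock()
out = recycle_forward(block, torch.randn(2, 8))
out.sum().backward()  # only the final iteration is part of the autograd graph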