Commit bb6f6145 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Slightly refactor recycling code to improve FP16 stability

parent b47138dc
...@@ -197,23 +197,24 @@ class AlphaFold(nn.Module): ...@@ -197,23 +197,24 @@ class AlphaFold(nn.Module):
feats["msa_feat"], feats["msa_feat"],
) )
# Inject information from previous recycling iterations
if _recycle:
# Initialize the recycling embeddings, if needs be # Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]: if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m] # [*, N, C_m]
m_1_prev = m.new_zeros( m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.input_embedder.c_m), (*batch_dims, n, self.config.input_embedder.c_m),
requires_grad=False,
) )
# [*, N, N, C_z] # [*, N, N, C_z]
z_prev = z.new_zeros( z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.input_embedder.c_z), (*batch_dims, n, n, self.config.input_embedder.c_z),
requires_grad=False,
) )
# [*, N, 3] # [*, N, 3]
x_prev = z.new_zeros( x_prev = z.new_zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3), (*batch_dims, n, residue_constants.atom_type_num, 3),
requires_grad=False,
) )
x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None) x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None)
...@@ -226,11 +227,19 @@ class AlphaFold(nn.Module): ...@@ -226,11 +227,19 @@ class AlphaFold(nn.Module):
x_prev, x_prev,
) )
# If the number of recycling iterations is 0, skip recycling
# altogether. We zero them this way instead of computing them
# conditionally to avoid leaving parameters unused, which has annoying
# implications for DDP training.
if(not _recycle):
m_1_prev_emb *= 0
z_prev_emb *= 0
# [*, S_c, N, C_m] # [*, S_c, N, C_m]
m[..., 0, :, :] = m[..., 0, :, :] + m_1_prev_emb m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z] # [*, N, N, C_z]
z = z + z_prev_emb z += z_prev_emb
# Possibly prevents memory fragmentation # Possibly prevents memory fragmentation
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
...@@ -408,11 +417,12 @@ class AlphaFold(nn.Module): ...@@ -408,11 +417,12 @@ class AlphaFold(nn.Module):
# Enable grad iff we're training and it's the final recycling layer # Enable grad iff we're training and it's the final recycling layer
is_final_iter = cycle_no == (num_iters - 1) is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter): with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
# Sidestep AMP bug (PyTorch issue #65766)
if is_final_iter: if is_final_iter:
self._enable_activation_checkpointing() self._enable_activation_checkpointing()
# Sidestep AMP bug (PyTorch issue #65766)
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
torch.clear_autocast_cache() torch.clear_autocast_cache()
# Run the next iteration of the model # Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration( outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats, feats,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment