"...git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "1f5b7872b03e9e3d42801872bc59681ef36357b5"
Commit bb6f6145 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Slightly refactor recycling code to improve FP16 stability

parent b47138dc
......@@ -197,43 +197,52 @@ class AlphaFold(nn.Module):
feats["msa_feat"],
)
# Inject information from previous recycling iterations
if _recycle:
# Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m]
m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.input_embedder.c_m),
)
# Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m]
m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.input_embedder.c_m),
requires_grad=False,
)
# [*, N, N, C_z]
z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.input_embedder.c_z),
)
# [*, N, N, C_z]
z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.input_embedder.c_z),
requires_grad=False,
)
# [*, N, 3]
x_prev = z.new_zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3),
)
# [*, N, 3]
x_prev = z.new_zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3),
requires_grad=False,
)
x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None)
x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None)
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
x_prev,
)
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
x_prev,
)
# [*, S_c, N, C_m]
m[..., 0, :, :] = m[..., 0, :, :] + m_1_prev_emb
# If the number of recycling iterations is 0, skip recycling
# altogether. We zero them this way instead of computing them
# conditionally to avoid leaving parameters unused, which has annoying
# implications for DDP training.
if(not _recycle):
m_1_prev_emb *= 0
z_prev_emb *= 0
# [*, N, N, C_z]
z = z + z_prev_emb
# [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev_emb
# Possibly prevents memory fragmentation
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
# [*, N, N, C_z]
z += z_prev_emb
# Possibly prevents memory fragmentation
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled:
......@@ -408,11 +417,12 @@ class AlphaFold(nn.Module):
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
# Sidestep AMP bug (PyTorch issue #65766)
if is_final_iter:
self._enable_activation_checkpointing()
# Sidestep AMP bug (PyTorch issue #65766)
if torch.is_autocast_enabled():
torch.clear_autocast_cache()
# Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment