"lib/bindings/python/vscode:/vscode.git/clone" did not exist on "08fcd7e93ba5df3093a8b54fe79e0895fe7a5f15"
Commit f34bef8e authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Sidestep AMP bug

parent 304b5ff7
......@@ -235,6 +235,9 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
z = z + z_prev_emb
# This can matter during inference when N_res is very large
del m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if(self.config.template.enabled):
template_feats = {
......@@ -332,16 +335,6 @@ class AlphaFold(nn.Module):
self.config.extra_msa.extra_msa_stack.blocks_per_ckpt
)
def _disable_grad(self):
vals = [p.requires_grad for p in self.parameters()]
for p in self.parameters():
p.requires_grad_(False)
return vals
def _enable_grad(self, vals):
for p, v in zip(self.parameters(), vals):
p.requires_grad_(v)
def forward(self, batch):
"""
Args:
......@@ -394,27 +387,25 @@ class AlphaFold(nn.Module):
"""
# Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None
# Disable activation checkpointing until the final recycling layer
self._disable_activation_checkpointing()
grad_vals = self._disable_grad()
is_grad_enabled = torch.is_grad_enabled()
# Main recycling loop
for cycle_no in range(self.config.no_cycles):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = (cycle_no == (self.config.no_cycles - 1))
if(is_final_iter):
self._enable_activation_checkpointing()
self._enable_grad(grad_vals)
# Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats, m_1_prev, z_prev, x_prev,
)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
# Sidestep AMP bug discussed in pytorch issue #65766
if(is_final_iter and torch.is_autocast_enabled()):
torch.clear_autocast_cache()
# Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats, m_1_prev, z_prev, x_prev,
)
# Run auxiliary heads
outputs.update(self.aux_heads(outputs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment