Commit f34bef8e authored by Gustaf Ahdritz

Sidestep AMP bug

parent 304b5ff7
@@ -235,6 +235,9 @@ class AlphaFold(nn.Module):
         # [*, N, N, C_z]
         z = z + z_prev_emb
 
+        # This can matter during inference when N_res is very large
+        del m_1_prev_emb, z_prev_emb
+
         # Embed the templates + merge with MSA/pair embeddings
         if(self.config.template.enabled):
             template_feats = {
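A note on the hunk above: z_prev_emb is a [*, N, N, C_z] pair tensor, so keeping a reference to it after it has been folded into z holds a second O(N_res^2) buffer alive through the template and extra-MSA stages. Deleting the reference lets the allocator reuse that memory right away. A minimal standalone sketch of the idea (plain PyTorch, not OpenFold code; the sizes are made up for illustration):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-ins for the recycling embeddings; shapes chosen only for illustration
n_res, c_z = 500, 128
z = torch.zeros(n_res, n_res, c_z, device=device)           # ~128 MB in fp32
z_prev_emb = torch.zeros(n_res, n_res, c_z, device=device)   # another ~128 MB

# Consume the recycling embedding...
z = z + z_prev_emb

# ...then drop the last reference so its buffer can be reused before the
# next memory-hungry stage runs.
del z_prev_emb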
@@ -332,16 +335,6 @@ class AlphaFold(nn.Module):
             self.config.extra_msa.extra_msa_stack.blocks_per_ckpt
         )
 
-    def _disable_grad(self):
-        vals = [p.requires_grad for p in self.parameters()]
-        for p in self.parameters():
-            p.requires_grad_(False)
-        return vals
-
-    def _enable_grad(self, vals):
-        for p, v in zip(self.parameters(), vals):
-            p.requires_grad_(v)
-
     def forward(self, batch):
         """
         Args:
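The helpers removed above toggled requires_grad on every parameter and then restored the saved flags. The hunk below replaces that bookkeeping with a torch.set_grad_enabled context scoped to each recycling iteration, which gives the same "no autograd graph until the last cycle" behaviour without mutating module state. A rough standalone sketch of the replacement pattern, with a toy nn.Linear standing in for the full model (not OpenFold code):

import torch

model = torch.nn.Linear(8, 8)   # toy stand-in for the real module
x = torch.randn(4, 8)

is_grad_enabled = torch.is_grad_enabled()
n_cycles = 3
for cycle_no in range(n_cycles):
    is_final_iter = cycle_no == (n_cycles - 1)
    # Track autograd only on the final recycling iteration, and only if
    # grad was enabled to begin with (i.e. not under torch.no_grad()).
    with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
        x = model(x)

print(x.requires_grad)  # True here; False if the loop runs under torch.no_grad()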
@@ -394,27 +387,25 @@ class AlphaFold(nn.Module):
         """
         # Initialize recycling embeddings
         m_1_prev, z_prev, x_prev = None, None, None
 
-        # Disable activation checkpointing until the final recycling layer
-        self._disable_activation_checkpointing()
-        grad_vals = self._disable_grad()
+        is_grad_enabled = torch.is_grad_enabled()
 
         # Main recycling loop
         for cycle_no in range(self.config.no_cycles):
             # Select the features for the current recycling cycle
             fetch_cur_batch = lambda t: t[..., cycle_no]
             feats = tensor_tree_map(fetch_cur_batch, batch)
 
             # Enable grad iff we're training and it's the final recycling layer
             is_final_iter = (cycle_no == (self.config.no_cycles - 1))
-            if(is_final_iter):
-                self._enable_activation_checkpointing()
-                self._enable_grad(grad_vals)
-
-            # Run the next iteration of the model
-            outputs, m_1_prev, z_prev, x_prev = self.iteration(
-                feats, m_1_prev, z_prev, x_prev,
-            )
+            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
+                # Sidestep AMP bug discussed in pytorch issue #65766
+                if(is_final_iter and torch.is_autocast_enabled()):
+                    torch.clear_autocast_cache()
+                # Run the next iteration of the model
+                outputs, m_1_prev, z_prev, x_prev = self.iteration(
+                    feats, m_1_prev, z_prev, x_prev,
+                )
 
         # Run auxiliary heads
         outputs.update(self.aux_heads(outputs))
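On the autocast workaround itself: the change targets the interaction between autocast's weight-cast cache and toggling grad mode inside an autocast region (pytorch issue #65766, referenced in the diff). As I understand it, half-precision copies of parameters cached during the earlier no-grad cycles can be reused in the grad-enabled final cycle, so clearing the cache just before that cycle forces the casts to be redone where autograd can see them. A standalone sketch of the pattern (toy model and shapes, not OpenFold code):

import torch

def run_recycling(model, x, n_cycles=3):
    is_grad_enabled = torch.is_grad_enabled()
    for cycle_no in range(n_cycles):
        is_final_iter = cycle_no == (n_cycles - 1)
        with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
            # Weight casts cached by autocast during the earlier no-grad
            # cycles should not be reused in the grad-enabled final cycle
            # (pytorch issue #65766), so drop the cache before it runs.
            if is_final_iter and torch.is_autocast_enabled():
                torch.clear_autocast_cache()
            x = model(x)
    return x

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()
    x = torch.randn(2, 16, device="cuda")
    with torch.cuda.amp.autocast():
        out = run_recycling(model, x)
    out.float().sum().backward()  # gradients flow back to model.weight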