Commit de0fa7b1 authored by Gustaf Ahdritz
Browse files

Fix tensor casting, improve TorchScript compatibility

parent 7d53297c
...@@ -430,7 +430,7 @@ class ExtraMSAStack(nn.Module): ...@@ -430,7 +430,7 @@ class ExtraMSAStack(nn.Module):
Optional [*, N_res, N_res] pair mask Optional [*, N_res, N_res] pair mask
Returns: Returns:
[*, N_res, N_res, C_z] pair update [*, N_res, N_res, C_z] pair update
""" """
_, z, _ = self.stack( _, z, _ = self.stack(
m, m,
z, z,
......
...@@ -43,7 +43,6 @@ from openfold.model.template import ( ...@@ -43,7 +43,6 @@ from openfold.model.template import (
from openfold.utils.loss import ( from openfold.utils.loss import (
compute_plddt, compute_plddt,
) )
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
dict_multimap, dict_multimap,
tensor_tree_map, tensor_tree_map,
...@@ -162,7 +161,7 @@ class AlphaFold(nn.Module): ...@@ -162,7 +161,7 @@ class AlphaFold(nn.Module):
z, z,
template_mask=batch["template_mask"] template_mask=batch["template_mask"]
) )
t = t * torch.sum(batch["template_mask"]) > 0 t = t * (torch.sum(batch["template_mask"]) > 0)
return { return {
"template_angle_embedding": a, "template_angle_embedding": a,
...@@ -318,6 +317,22 @@ class AlphaFold(nn.Module): ...@@ -318,6 +317,22 @@ class AlphaFold(nn.Module):
return outputs, m_1_prev, z_prev, x_prev return outputs, m_1_prev, z_prev, x_prev
def _disable_activation_checkpointing(self):
self.template_pair_stack.blocks_per_ckpt = None
self.evoformer.blocks_per_ckpt = None
self.extra_msa_stack.stack.blocks_per_ckpt = None
def _enable_activation_checkpointing(self):
self.template_pair_stack.blocks_per_ckpt = (
self.config.template.template_pair_stack.blocks_per_ckpt
)
self.evoformer.blocks_per_ckpt = (
self.config.evoformer_stack.blocks_per_ckpt
)
self.extra_msa_stack.stack.blocks_per_ckpt = (
self.config.extra_msa.extra_msa_stack.blocks_per_ckpt
)
def forward(self, batch): def forward(self, batch):
""" """
Args: Args:
...@@ -368,9 +383,12 @@ class AlphaFold(nn.Module): ...@@ -368,9 +383,12 @@ class AlphaFold(nn.Module):
"template_pseudo_beta_mask" ([*, N_templ, N_res]) "template_pseudo_beta_mask" ([*, N_templ, N_res])
Pseudo-beta mask Pseudo-beta mask
""" """
# Recycling embeddings # Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None m_1_prev, z_prev, x_prev = None, None, None
# Disable activation checkpointing until the final recycling layer
self._disable_activation_checkpointing()
# Main recycling loop # Main recycling loop
for cycle_no in range(self.config.no_cycles): for cycle_no in range(self.config.no_cycles):
# Select the features for the current recycling cycle # Select the features for the current recycling cycle
...@@ -379,9 +397,11 @@ class AlphaFold(nn.Module): ...@@ -379,9 +397,11 @@ class AlphaFold(nn.Module):
# Enable grad iff we're training and it's the final recycling layer # Enable grad iff we're training and it's the final recycling layer
is_final_iter = (cycle_no == self.config.no_cycles - 1) is_final_iter = (cycle_no == self.config.no_cycles - 1)
if(self.training and is_final_iter):
self._enable_activation_checkpointing()
with torch.set_grad_enabled(self.training and is_final_iter): with torch.set_grad_enabled(self.training and is_final_iter):
outputs, m_1_prev, z_prev, x_prev = self.iteration( outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats, m_1_prev, z_prev, x_prev, feats, m_1_prev, z_prev, x_prev,
) )
outputs.update(self.aux_heads(outputs)) outputs.update(self.aux_heads(outputs))
......
...@@ -313,7 +313,7 @@ class GlobalAttention(nn.Module): ...@@ -313,7 +313,7 @@ class GlobalAttention(nn.Module):
# [*, N_res, H * C_hidden] # [*, N_res, H * C_hidden]
q = self.linear_q(q) q = self.linear_q(q)
q = q * self.c_hidden ** (-0.5) q = q * (self.c_hidden ** (-0.5))
# [*, N_res, H, C_hidden] # [*, N_res, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1)) q = q.view(q.shape[:-1] + (self.no_heads, -1))
......
...@@ -70,7 +70,8 @@ def checkpoint_blocks( ...@@ -70,7 +70,8 @@ def checkpoint_blocks(
for s in range(0, len(blocks), blocks_per_ckpt): for s in range(0, len(blocks), blocks_per_ckpt):
e = s + blocks_per_ckpt e = s + blocks_per_ckpt
args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args) args = checkpoint(chunker(s, e), *args)
#args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
args = wrap(args) args = wrap(args)
return args return args
...@@ -279,7 +279,8 @@ def atom37_to_torsion_angles( ...@@ -279,7 +279,8 @@ def atom37_to_torsion_angles(
) )
torsion_angles_sin_cos = torch.stack( torsion_angles_sin_cos = torch.stack(
[fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1) [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
)
denom = torch.sqrt( denom = torch.sqrt(
torch.sum( torch.sum(
torch.square(torsion_angles_sin_cos), dim=-1, keepdims=True torch.square(torsion_angles_sin_cos), dim=-1, keepdims=True
...@@ -336,7 +337,7 @@ def atom37_to_frames( ...@@ -336,7 +337,7 @@ def atom37_to_frames(
restype_rigidgroup_mask = torch.zeros( restype_rigidgroup_mask = torch.zeros(
(*aatype.shape[:-1], 21, 8), (*aatype.shape[:-1], 21, 8),
dtype=torch.float, dtype=all_atom_mask.dtype,
device=aatype.device, device=aatype.device,
requires_grad=False requires_grad=False
) )
...@@ -390,14 +391,16 @@ def atom37_to_frames( ...@@ -390,14 +391,16 @@ def atom37_to_frames(
) )
gt_atoms_exist = batched_gather( gt_atoms_exist = batched_gather(
all_atom_mask.float(), all_atom_mask,
residx_rigidgroup_base_atom37_idx, residx_rigidgroup_base_atom37_idx,
dim=-1, dim=-1,
no_batch_dims=len(all_atom_mask.shape[:-1]) no_batch_dims=len(all_atom_mask.shape[:-1])
) )
gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
rots = torch.eye(3, device=aatype.device, requires_grad=False) rots = torch.eye(
3, dtype=all_atom_mask.dtype, device=aatype.device, requires_grad=False
)
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1)) rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1 rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1 rots[..., 0, 2, 2] = -1
...@@ -408,7 +411,7 @@ def atom37_to_frames( ...@@ -408,7 +411,7 @@ def atom37_to_frames(
*((1,) * batch_dims), 21, 8 *((1,) * batch_dims), 21, 8
) )
restype_rigidgroup_rots = torch.eye( restype_rigidgroup_rots = torch.eye(
3, device=aatype.device, requires_grad=False 3, dtype=all_atom_mask.dtype, device=aatype.device, requires_grad=False
) )
restype_rigidgroup_rots = torch.tile( restype_rigidgroup_rots = torch.tile(
restype_rigidgroup_rots, restype_rigidgroup_rots,
......
...@@ -1385,9 +1385,9 @@ class AlphaFoldLoss(nn.Module): ...@@ -1385,9 +1385,9 @@ class AlphaFoldLoss(nn.Module):
for k,loss_fn in loss_fns.items(): for k,loss_fn in loss_fns.items():
weight = self.config[k].weight weight = self.config[k].weight
if(weight): if(weight):
print(k)
loss = loss_fn() loss = loss_fn()
print(loss) #print(k)
#print(loss)
cum_loss += weight * loss cum_loss += weight * loss
return cum_loss return cum_loss
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment