Commit de0fa7b1 authored by Gustaf Ahdritz

Fix tensor casting, improve TorchScript compatibility

parent 7d53297c
......@@ -430,7 +430,7 @@ class ExtraMSAStack(nn.Module):
Optional [*, N_res, N_res] pair mask
Returns:
[*, N_res, N_res, C_z] pair update
"""
"""
_, z, _ = self.stack(
m,
z,
......
......@@ -43,7 +43,6 @@ from openfold.model.template import (
from openfold.utils.loss import (
compute_plddt,
)
from openfold.utils.tensor_utils import (
dict_multimap,
tensor_tree_map,
......@@ -162,7 +161,7 @@ class AlphaFold(nn.Module):
z,
template_mask=batch["template_mask"]
)
t = t * torch.sum(batch["template_mask"]) > 0
t = t * (torch.sum(batch["template_mask"]) > 0)
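In Python the * operator binds more tightly than >, so the old line evaluates (t * torch.sum(batch["template_mask"])) > 0 and replaces the template embedding with a boolean tensor. The parenthesized version computes the comparison first, giving a 0/1 gate that zeroes t when no templates are present. A minimal sketch of the difference (shapes and values here are illustrative only, not the model's actual dimensions):

import torch

t = torch.ones(2, 3)                      # stand-in for the template embedding
template_mask = torch.zeros(4)            # no templates present in this example

old = t * torch.sum(template_mask) > 0    # parses as (t * sum) > 0 -> bool tensor
new = t * (torch.sum(template_mask) > 0)  # 0/1 gate scales t -> float zeros

print(old.dtype, new.dtype)               # torch.bool torch.float32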
return {
"template_angle_embedding": a,
......@@ -318,6 +317,22 @@ class AlphaFold(nn.Module):
return outputs, m_1_prev, z_prev, x_prev
def _disable_activation_checkpointing(self):
self.template_pair_stack.blocks_per_ckpt = None
self.evoformer.blocks_per_ckpt = None
self.extra_msa_stack.stack.blocks_per_ckpt = None
def _enable_activation_checkpointing(self):
self.template_pair_stack.blocks_per_ckpt = (
self.config.template.template_pair_stack.blocks_per_ckpt
)
self.evoformer.blocks_per_ckpt = (
self.config.evoformer_stack.blocks_per_ckpt
)
self.extra_msa_stack.stack.blocks_per_ckpt = (
self.config.extra_msa.extra_msa_stack.blocks_per_ckpt
)
def forward(self, batch):
"""
Args:
......@@ -368,9 +383,12 @@ class AlphaFold(nn.Module):
"template_pseudo_beta_mask" ([*, N_templ, N_res])
Pseudo-beta mask
"""
# Recycling embeddings
# Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None
# Disable activation checkpointing until the final recycling layer
self._disable_activation_checkpointing()
# Main recycling loop
for cycle_no in range(self.config.no_cycles):
# Select the features for the current recycling cycle
......@@ -379,9 +397,11 @@ class AlphaFold(nn.Module):
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = (cycle_no == self.config.no_cycles - 1)
if(self.training and is_final_iter):
self._enable_activation_checkpointing()
with torch.set_grad_enabled(self.training and is_final_iter):
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats, m_1_prev, z_prev, x_prev,
feats, m_1_prev, z_prev, x_prev,
)
outputs.update(self.aux_heads(outputs))
......
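The new _disable_activation_checkpointing and _enable_activation_checkpointing helpers switch block checkpointing off (blocks_per_ckpt = None) for every recycling iteration except the last: during training only the final cycle runs with autograd enabled, and checkpointing only pays off when a backward pass will actually happen. A stripped-down sketch of that gradient gating, with a placeholder standing in for AlphaFold.iteration:

import torch

w = torch.nn.Parameter(torch.tensor(2.0))   # stand-in for the model's weights

def iteration(x):
    # Placeholder for AlphaFold.iteration: anything that touches the weights.
    return x * w

x = torch.ones(3)
no_cycles, training = 3, True
for cycle_no in range(no_cycles):
    is_final_iter = cycle_no == no_cycles - 1
    with torch.set_grad_enabled(training and is_final_iter):
        x = iteration(x)

print(x.requires_grad)   # True: only the last cycle joined the autograd graph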
......@@ -313,7 +313,7 @@ class GlobalAttention(nn.Module):
# [*, N_res, H * C_hidden]
q = self.linear_q(q)
q = q * self.c_hidden ** (-0.5)
q = q * (self.c_hidden ** (-0.5))
# [*, N_res, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
......
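In standard Python, ** binds more tightly than *, so the two forms of the query scaling parse identically; the added parentheses make the intended 1/sqrt(c_hidden) attention scaling explicit rather than relying on operator precedence. A quick check (sizes are illustrative):

import torch

c_hidden = 16
q = torch.randn(2, 5, c_hidden)
# Both expressions parse the same way, so the scaled queries match exactly.
assert torch.equal(q * c_hidden ** (-0.5), q * (c_hidden ** (-0.5)))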
......@@ -70,7 +70,8 @@ def checkpoint_blocks(
for s in range(0, len(blocks), blocks_per_ckpt):
e = s + blocks_per_ckpt
args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
args = checkpoint(chunker(s, e), *args)
#args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
args = wrap(args)
return args
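The block-wise checkpointing utility now calls the stock torch.utils.checkpoint.checkpoint instead of deepspeed.checkpointing.checkpoint, dropping the DeepSpeed dependency for this code path. Below is a self-contained sketch of the same pattern; the helper is illustrative, not OpenFold's exact checkpoint_blocks interface, and blocks_per_ckpt = None is taken to mean "run the blocks directly", which is how the new _disable_activation_checkpointing toggle switches checkpointing off:

import torch
from torch.utils.checkpoint import checkpoint

def run_blocks(blocks, x, blocks_per_ckpt):
    # Illustrative stand-in for a checkpoint_blocks-style driver.
    def chunker(s, e):
        def run(t):
            for b in blocks[s:e]:
                t = b(t)
            return t
        return run

    if blocks_per_ckpt is None:
        return chunker(0, len(blocks))(x)            # plain forward, no recompute
    for s in range(0, len(blocks), blocks_per_ckpt):
        # Activations inside each chunk are dropped after the forward pass and
        # recomputed during backward, trading compute for memory.
        x = checkpoint(chunker(s, s + blocks_per_ckpt), x)
    return x

blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
x = torch.randn(2, 8, requires_grad=True)
run_blocks(blocks, x, blocks_per_ckpt=2).sum().backward()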
......@@ -279,7 +279,8 @@ def atom37_to_torsion_angles(
)
torsion_angles_sin_cos = torch.stack(
[fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1)
[fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
)
denom = torch.sqrt(
torch.sum(
torch.square(torsion_angles_sin_cos), dim=-1, keepdims=True
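The reflowed torch.stack call builds an unnormalized (sin, cos) pair for each torsion angle, and the denom above divides it down to unit length (with a small eps assumed for stability) so the angle can be recovered with atan2. A tiny numeric illustration:

import torch

raw = torch.tensor([0.3, 0.4])                            # unnormalized (sin, cos)
denom = torch.sqrt(torch.sum(torch.square(raw)) + 1e-8)   # eps value assumed here
sin_cos = raw / denom
print(torch.linalg.norm(sin_cos))                         # ~1.0, on the unit circle
print(torch.atan2(sin_cos[0], sin_cos[1]))                # the recovered angle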
......@@ -336,7 +337,7 @@ def atom37_to_frames(
restype_rigidgroup_mask = torch.zeros(
(*aatype.shape[:-1], 21, 8),
dtype=torch.float,
dtype=all_atom_mask.dtype,
device=aatype.device,
requires_grad=False
)
......@@ -390,14 +391,16 @@ def atom37_to_frames(
)
gt_atoms_exist = batched_gather(
all_atom_mask.float(),
all_atom_mask,
residx_rigidgroup_base_atom37_idx,
dim=-1,
no_batch_dims=len(all_atom_mask.shape[:-1])
)
gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
rots = torch.eye(3, device=aatype.device, requires_grad=False)
rots = torch.eye(
3, dtype=all_atom_mask.dtype, device=aatype.device, requires_grad=False
)
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1
......@@ -408,7 +411,7 @@ def atom37_to_frames(
*((1,) * batch_dims), 21, 8
)
restype_rigidgroup_rots = torch.eye(
3, device=aatype.device, requires_grad=False
3, dtype=all_atom_mask.dtype, device=aatype.device, requires_grad=False
)
restype_rigidgroup_rots = torch.tile(
restype_rigidgroup_rots,
......
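The dtype changes above are the tensor-casting fix from the commit title: buffers that were hard-coded to float32 (torch.zeros(..., dtype=torch.float) and torch.eye(...) without a dtype) and the explicit all_atom_mask.float() cast now follow all_atom_mask.dtype instead. Otherwise, arithmetic that mixes these float32 constants with a half- or bfloat16-precision batch silently promotes back to float32, which defeats reduced-precision inference and can trip ops that require matching dtypes. A small illustration (dtypes and shapes are illustrative only):

import torch

all_atom_mask = torch.ones(4, 37, dtype=torch.bfloat16)   # mask in the batch's dtype

old_buf = torch.zeros(4, 8, dtype=torch.float)            # hard-coded float32
new_buf = torch.zeros(4, 8, dtype=all_atom_mask.dtype)    # follows the batch

print((old_buf + all_atom_mask[:, :8]).dtype)             # torch.float32 (silent upcast)
print((new_buf + all_atom_mask[:, :8]).dtype)             # torch.bfloat16

print(torch.eye(3).dtype)                                  # default dtype, float32
print(torch.eye(3, dtype=all_atom_mask.dtype).dtype)       # matches the mask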
......@@ -1385,9 +1385,9 @@ class AlphaFoldLoss(nn.Module):
for k,loss_fn in loss_fns.items():
weight = self.config[k].weight
if(weight):
print(k)
loss = loss_fn()
print(loss)
#print(k)
#print(loss)
cum_loss += weight * loss
return cum_loss