Commit eb49136d authored by Gustaf Ahdritz

Finish purge of in-place ops, get grads working, add TM

parent 1d47c1e7
@@ -23,8 +23,8 @@ blocks_per_ckpt = mlc.FieldReference(1, field_type=int)
 chunk_size = mlc.FieldReference(None, field_type=int)
 aux_distogram_bins = mlc.FieldReference(64)
-eps = 1e-4
-inf = 1e4
+eps = 1e-8
+inf = 1e8

 config = mlc.ConfigDict({
     "model": {
@@ -33,7 +33,7 @@ config = mlc.ConfigDict({
         "c_t": c_t,
         "c_e": c_e,
         "c_s": c_s,
-        "no_cycles": 4,
+        "no_cycles": 2,#4,
         "_mask_trans": False,
         "input_embedder": {
             "tf_dim": 22,
@@ -117,7 +117,7 @@ config = mlc.ConfigDict({
                 "inf": inf,#1e9,
                 "eps": eps,#1e-10,
             },
-            "enabled": False,#True,
+            "enabled": True,
         },
         "evoformer_stack": {
             "c_m": c_m,
@@ -147,7 +147,7 @@ config = mlc.ConfigDict({
             "no_qk_points": 4,
             "no_v_points": 8,
             "dropout_rate": 0.1,
-            "no_blocks": 8,
+            "no_blocks": 2,#8,
             "no_transition_layers": 1,
             "no_resnet_blocks": 2,
             "no_angles": 7,
@@ -165,10 +165,10 @@ config = mlc.ConfigDict({
                 "c_z": c_z,
                 "no_bins": aux_distogram_bins,
             },
-            "tm_score": {
+            "tm": {
                 "c_z": c_z,
                 "no_bins": aux_distogram_bins,
-                "enabled": False,
+                "enabled": True,
             },
             "masked_msa": {
                 "c_m": c_m,
@@ -239,6 +239,14 @@ config = mlc.ConfigDict({
             "eps": eps,#1e-6,
             "weight": 0.,
         },
+        "tm": {
+            "max_bin": 31,
+            "no_bins": 64,
+            "min_resolution": 0.1,
+            "max_resolution": 3.0,
+            "eps": eps,#1e-8,
+            "weight": 1.0,
+        },
         "eps": eps,
     },
 })
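Note: the tm head and tm loss are switched on here, and the new loss.tm block mirrors the keyword arguments that tm_loss() (see loss.py below) reads. For context, a minimal sketch of the ml_collections pattern the config relies on, where a single FieldReference such as aux_distogram_bins is shared by several entries (illustrative snippet, not OpenFold code):

    import ml_collections as mlc

    no_bins = mlc.FieldReference(64, field_type=int)
    cfg = mlc.ConfigDict({
        "distogram": {"no_bins": no_bins},
        "tm": {"no_bins": no_bins},
    })
    no_bins.set(32)   # both heads now resolve to the updated value
    print(cfg.distogram.no_bins, cfg.tm.no_bins)   # 32 32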
@@ -53,7 +53,7 @@ class Dropout(nn.Module):
         if(self.batch_dim is not None):
             for bd in self.batch_dim:
                 shape[bd] = 1
-        mask = x.new_ones(shape, requires_grad=False)
+        mask = x.new_ones(shape)
         mask = self.dropout(mask)
         x = x * mask
         return x
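Note on the requires_grad=False purge in this and the following hunks: Tensor.new_ones/new_zeros already default to requires_grad=False and inherit dtype and device from the source tensor, so dropping the explicit flag changes nothing. Quick standalone check:

    import torch

    x = torch.randn(2, 3, dtype=torch.float32, requires_grad=True)
    mask = x.new_ones((1, 3))
    print(mask.dtype, mask.device, mask.requires_grad)   # torch.float32 cpu False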
...
@@ -194,7 +194,6 @@ class RecyclingEmbedder(nn.Module):
             self.max_bin,
             self.no_bins,
             dtype=x.dtype,
-            requires_grad=False,
             device=x.device
         )
...
@@ -40,9 +40,9 @@ class AuxiliaryHeads(nn.Module):
                 **config["experimentally_resolved"],
             )

-        if(config.tm_score.enabled):
-            self.tm_score = TMScoreHead(
-                **config["tm_score"],
+        if(config.tm.enabled):
+            self.tm = TMScoreHead(
+                **config.tm,
             )

         self.config = config
@@ -68,9 +68,9 @@ class AuxiliaryHeads(nn.Module):
                 experimentally_resolved_logits
             )

-        if(self.config.tm_score.enabled):
-            tm_score_logits = self.tm_score(outputs["pair"])
-            aux_out["tm_score_logits"] = tm_score_logits
+        if(self.config.tm.enabled):
+            tm_logits = self.tm(outputs["pair"])
+            aux_out["tm_logits"] = tm_logits

         return aux_out
...
@@ -115,10 +115,6 @@ class AlphaFold(nn.Module):
                 batch,
             )

-            #tensor_dtype = (
-            #    single_template_feats["template_all_atom_masks"].dtype
-            #)
-
             # Build template angle feats
             angle_feats = atom37_to_torsion_angles(
                 single_template_feats["template_aatype"],
@@ -127,10 +123,6 @@ class AlphaFold(nn.Module):
                 eps=self.config.template.eps,
             )

-            #angle_feats = tensor_tree_map(
-            #    lambda t: t.type(tensor_dtype), angle_feats
-            #)
-
             template_angle_feat = build_template_angle_feat(
                 angle_feats,
                 single_template_feats["template_aatype"],
@@ -211,19 +203,16 @@ class AlphaFold(nn.Module):
             # [*, N, C_m]
             m_1_prev = m.new_zeros(
                 (*batch_dims, n, self.config.c_m),
-                requires_grad=False,
             )

             # [*, N, N, C_z]
             z_prev = z.new_zeros(
                 (*batch_dims, n, n, self.config.c_z),
-                requires_grad=False,
             )

             # [*, N, 3]
             x_prev = z.new_zeros(
                 (*batch_dims, n, residue_constants.atom_type_num, 3),
-                requires_grad=False,
             )

         x_prev = pseudo_beta_fn(
@@ -241,7 +230,7 @@ class AlphaFold(nn.Module):
         )

         # [*, S_c, N, C_m]
-        m[..., 0, :, :] += m_1_prev_emb
+        m[..., 0, :, :] = m[..., 0, :, :] + m_1_prev_emb

         # [*, N, N, C_z]
         z = z + z_prev_emb
@@ -312,6 +301,7 @@ class AlphaFold(nn.Module):
             outputs["sm"]["positions"][-1], feats
         )
         outputs["final_atom_mask"] = feats["atom37_atom_exists"]
+        outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]

         # Save embeddings for use during the next recycling iteration
@@ -342,6 +332,16 @@ class AlphaFold(nn.Module):
                 self.config.extra_msa.extra_msa_stack.blocks_per_ckpt
             )

+    def _disable_grad(self):
+        vals = [p.requires_grad for p in self.parameters()]
+        for p in self.parameters():
+            p.requires_grad_(False)
+        return vals
+
+    def _enable_grad(self, vals):
+        for p, v in zip(self.parameters(), vals):
+            p.requires_grad_(v)
+
     def forward(self, batch):
         """
         Args:
@@ -391,12 +391,13 @@ class AlphaFold(nn.Module):
                     for which C_alpha is used instead)
                 "template_pseudo_beta_mask" ([*, N_templ, N_res])
                     Pseudo-beta mask
        """
        # Initialize recycling embeddings
        m_1_prev, z_prev, x_prev = None, None, None

        # Disable activation checkpointing until the final recycling layer
        self._disable_activation_checkpointing()
+        grad_vals = self._disable_grad()

        # Main recycling loop
        for cycle_no in range(self.config.no_cycles):
@@ -405,14 +406,17 @@ class AlphaFold(nn.Module):
            feats = tensor_tree_map(fetch_cur_batch, batch)

            # Enable grad iff we're training and it's the final recycling layer
-            is_final_iter = (cycle_no == self.config.no_cycles - 1)
-            if(self.training and is_final_iter):
+            is_final_iter = (cycle_no == (self.config.no_cycles - 1))
+            if(is_final_iter):
                self._enable_activation_checkpointing()
-            with torch.set_grad_enabled(self.training and is_final_iter):
-                outputs, m_1_prev, z_prev, x_prev = self.iteration(
-                    feats, m_1_prev, z_prev, x_prev,
-                )
+                self._enable_grad(grad_vals)
+
+            # Run the next iteration of the model
+            outputs, m_1_prev, z_prev, x_prev = self.iteration(
+                feats, m_1_prev, z_prev, x_prev,
+            )
+
+        # Run auxiliary heads
        outputs.update(self.aux_heads(outputs))

        return outputs
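Note: recycling now runs every iteration in normal grad mode, but parameter gradients are switched off for all non-final iterations. _disable_grad() records each parameter's requires_grad flag and clears it; _enable_grad() restores the recorded flags right before the final iteration, which is also the only one run with activation checkpointing. The same pattern on a toy module (illustrative only, not OpenFold code):

    import torch

    model = torch.nn.Linear(8, 8)
    saved = [p.requires_grad for p in model.parameters()]   # _disable_grad()
    for p in model.parameters():
        p.requires_grad_(False)

    # ... run the cheap, non-final recycling iterations without param grads ...

    for p, v in zip(model.parameters(), saved):              # _enable_grad(saved)
        p.requires_grad_(v)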
@@ -94,10 +94,8 @@ class MSAAttention(nn.Module):
         n_seq, n_res = m.shape[-3:-1]
         if(mask is None):
             # [*, N_seq, N_res]
-            mask = torch.ones(
+            mask = m.new_ones(
                 m.shape[:-3] + (n_seq, n_res),
-                device=m.device,
-                requires_grad=False
             )

         # [*, N_seq, 1, 1, N_res]
...
@@ -70,7 +70,7 @@ class OuterProductMean(nn.Module):
             [*, N_res, N_res, C_z] pair embedding update
         """
         if(mask is None):
-            mask = m.new_ones(m.shape[:-1], requires_grad=False)
+            mask = m.new_ones(m.shape[:-1])

         # [*, N_seq, N_res, C_m]
         m = self.layer_norm(m)
...
@@ -64,7 +64,7 @@ class PairTransition(nn.Module):
         """
         # DISCREPANCY: DeepMind forgets to apply the mask in this module.
         if(mask is None):
-            mask = z.new_ones(z.shape[:-1], requires_grad=False)
+            mask = z.new_ones(z.shape[:-1])

         # [*, N_res, N_res, 1]
         mask = mask.unsqueeze(-1)
...
@@ -251,10 +251,10 @@ class Attention(nn.Module):
             permute_final_dims(k, (0, 2, 3, 1)),  # [*, H, C_hidden, K]
         )
         norm = 1 / math.sqrt(self.c_hidden)  # [1]
-        a *= norm
+        a = a * norm
         if(biases is not None):
             for b in biases:
-                a += b
+                a = a + b
         a = self.softmax(a)

         #print(torch.any(torch.isnan(a)))
@@ -330,7 +330,7 @@ class GlobalAttention(nn.Module):
             k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
         )
         bias = (self.inf * (mask - 1))[..., :, None, :]
-        a += bias
+        a = a + bias
         a = self.softmax(a)

         # [*, N_res, H, C_hidden]
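Note: the a *= norm / a += b style updates are replaced because an in-place write to a tensor that autograd saved for the backward pass trips the version-counter check. A small reproduction of the failure mode this commit avoids (toy shapes):

    import torch

    w = torch.randn(3, requires_grad=True)
    a = torch.softmax(w, dim=-1)     # softmax saves its output for backward
    a *= 2.0                         # in-place edit bumps the version counter
    # a.sum().backward()             # would raise: "... modified by an inplace operation"

    a = torch.softmax(w, dim=-1) * 2.0   # out-of-place form used in this commit
    a.sum().backward()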
...
@@ -27,7 +27,7 @@ from openfold.np.residue_constants import (
 )
 from openfold.utils.affine_utils import T, quat_to_rot
 from openfold.utils.tensor_utils import (
-    stack_tensor_dicts,
+    dict_multimap,
     permute_final_dims,
     flatten_final_dims,
 )
@@ -337,10 +337,15 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H * C_hidden]
         o = flatten_final_dims(o, 2)

+        # As DeepMind explains, this manual matmul ensures that the operation
+        # happens in float32.
         # [*, H, 3, N_res, P_v]
-        o_pt = torch.matmul(
-            a.unsqueeze(-3),  # [*, H, 1, N_res, N_res]
-            permute_final_dims(v_pts, (1, 3, 0, 2)),  # [*, H, 3, N_res, P_v]
+        o_pt = torch.sum(
+            (
+                a[..., None, :, :, None] *
+                permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
+            ),
+            dim=-2
         )

         # [*, N_res, H, P_v, 3]
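Note: the torch.matmul over the attention-weighted points becomes an explicit broadcasted multiply plus a sum over the key dimension. A quick numerical check that the two forms agree, using toy shapes [*, H, N_res, N_res] for a and [*, H, 3, N_res, P_v] for the permuted value points:

    import torch

    a = torch.randn(2, 4, 5, 5)        # [*, H, N_res, N_res]
    v = torch.randn(2, 4, 3, 5, 8)     # [*, H, 3, N_res, P_v]
    ref = torch.matmul(a.unsqueeze(-3), v)
    alt = torch.sum(a[..., None, :, :, None] * v[..., None, :, :], dim=-2)
    print(torch.allclose(ref, alt, atol=1e-5))   # True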
@@ -702,7 +707,7 @@ class StructureModule(nn.Module):
         """
         if(mask is None):
             # [*, N]
-            mask = s.new_ones(s.shape[:-1], requires_grad=False)
+            mask = s.new_ones(s.shape[:-1])

         # [*, N, C_s]
         s = self.layer_norm_s(s)
@@ -718,7 +723,7 @@ class StructureModule(nn.Module):
         t = T.identity(s.shape[:-1], s.dtype, s.device, self.training)
         outputs = []
-        for l in range(self.no_blocks):
+        for i in range(self.no_blocks):
             # [*, N, C_s]
             s = s + self.ipa(s, z, t, mask)
             s = self.ipa_dropout(s)
@@ -751,10 +756,10 @@ class StructureModule(nn.Module):
             outputs.append(preds)

-            if(l < self.no_blocks - 1):
+            if(i < (self.no_blocks - 1)):
                 t = t.stop_rot_gradient()

-        outputs = stack_tensor_dicts(outputs)
+        outputs = dict_multimap(torch.stack, outputs)
         outputs["single"] = s

         return outputs
@@ -765,27 +770,23 @@ class StructureModule(nn.Module):
                 restype_rigid_group_default_frame,
                 dtype=float_dtype,
                 device=device,
-                requires_grad=False,
             )
         if(self.group_idx is None):
             self.group_idx = torch.tensor(
                 restype_atom14_to_rigid_group,
                 device=device,
-                requires_grad=False,
             )
         if(self.atom_mask is None):
             self.atom_mask = torch.tensor(
                 restype_atom14_mask,
                 dtype=float_dtype,
                 device=device,
-                requires_grad=False,
             )
         if(self.lit_positions is None):
             self.lit_positions = torch.tensor(
                 restype_atom14_rigid_group_positions,
                 dtype=float_dtype,
                 device=device,
-                requires_grad=False,
             )

     def torsion_angles_to_frames(self, t, alpha, f):
@@ -799,8 +800,6 @@ class StructureModule(nn.Module):
         f  # [*, N]
     ):
         # Lazily initialize the residue constants on the correct device
-        # TODO: Maybe this stuff should be done on CPU instead (so these
-        # arrays
         self._init_residue_constants(t.rots.dtype, t.rots.device)
         return _frames_and_literature_positions_to_atom14_pos(
             t,
...
@@ -73,10 +73,8 @@ class TriangleAttention(nn.Module):
         """
         if(mask is None):
             # [*, I, J]
-            mask = torch.ones(
+            mask = x.new_ones(
                 x.shape[:-1],
-                device=x.device,
-                requires_grad=False,
             )

         # Shape annotations assume self.starting. Else, I and J are flipped
...
@@ -91,7 +91,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
             [*, N_res, N_res, C_z] output tensor
         """
         if(mask is None):
-            mask = z.new_ones(z.shape[:-1], requires_grad=False)
+            mask = z.new_ones(z.shape[:-1])

         mask = mask.unsqueeze(-1)
...
@@ -163,7 +163,7 @@ class T:
         return trans

     @staticmethod
-    def identity(shape, dtype, device, requires_grad=False):
+    def identity(shape, dtype, device, requires_grad=True):
         return T(
             T.identity_rot(shape, dtype, device, requires_grad),
             T.identity_trans(shape, dtype, device, requires_grad),
@@ -191,11 +191,6 @@ class T:
         e0 = origin - p_neg_x_axis
         e1 = p_xy_plane - origin

-        # Angle norming is very sensitive to floating point imprecisions
-        #float_type = e0.dtype
-        #e0 = e0.float()
-        #e1 = e1.float()
-
         e0 = e0 / torch.sqrt(torch.sum(e0 ** 2, dim=-1, keepdims=True) + eps)
         e1 = e1 - e0 * torch.sum(e0 * e1, dim=-1, keepdims=True)
         e1 = e1 / torch.sqrt(torch.sum(e1 ** 2, dim=-1, keepdims=True) + eps)
@@ -203,8 +198,6 @@ class T:
         rots = torch.stack([e0, e1, e2], dim=-1)

-        #rots = rots.type(float_type)
-
         return T(rots, origin)

     @staticmethod
@@ -221,7 +214,8 @@ class T:
         return T(rots, trans)

     def map_tensor_fn(self, fn):
-        """ Apply a function that takes a tensor as its only argument to the
+        """
+        Apply a function that takes a tensor as its only argument to the
             rotations and translations, treating the final two/one
             dimension(s), respectively, as batch dimensions.
@@ -253,7 +247,7 @@ class T:
         n_xyz = n_xyz + translation
         c_xyz = c_xyz + translation

-        c_x, c_y, c_z = [c_xyz[...,i] for i in range(3)]
+        c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
         norm = torch.sqrt(eps + c_x**2 + c_y**2)
         sin_c1 = -c_y / norm
         cos_c1 = c_x / norm
@@ -278,7 +272,7 @@ class T:
         c1_rots[..., 2, 0] = -1 * sin_c2
         c1_rots[..., 2, 2] = cos_c2

-        c_rots = rot_matmul(c2_rot_matrix, c1_rot_matrix)
+        c_rots = rot_matmul(c2_rots, c1_rots)
         n_xyz = rot_vec_mul(c_rots, n_xyz)

         _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
...
@@ -151,7 +151,7 @@ def atom14_to_atom37(atom14, batch):
         no_batch_dims=len(atom14.shape[:-2]),
     )

-    atom37_data *= batch["atom37_atom_exists"][..., None]
+    atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]

     return atom37_data
@@ -288,7 +288,7 @@ def atom37_to_torsion_angles(
     )
     torsion_angles_sin_cos = torsion_angles_sin_cos / denom

-    torsion_angles_sin_cos *= torch.tensor(
+    torsion_angles_sin_cos = torsion_angles_sin_cos * torch.tensor(
         [1., 1., -1., 1., 1., 1., 1.], device=aatype.device,
     )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
@@ -335,11 +335,8 @@ def atom37_to_frames(
             restype, chi_idx + 4, :
         ] = names[1:]

-    restype_rigidgroup_mask = torch.zeros(
+    restype_rigidgroup_mask = all_atom_mask.new_zeros(
         (*aatype.shape[:-1], 21, 8),
-        dtype=all_atom_mask.dtype,
-        device=aatype.device,
-        requires_grad=False
     )
     restype_rigidgroup_mask[..., 0] = 1
     restype_rigidgroup_mask[..., 3] = 1
@@ -399,7 +396,7 @@ def atom37_to_frames(
     gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

     rots = torch.eye(
-        3, dtype=all_atom_mask.dtype, device=aatype.device, requires_grad=False
+        3, dtype=all_atom_mask.dtype, device=aatype.device
     )
     rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
     rots[..., 0, 0, 0] = -1
@@ -411,7 +408,7 @@ def atom37_to_frames(
         *((1,) * batch_dims), 21, 8
     )
     restype_rigidgroup_rots = torch.eye(
-        3, dtype=all_atom_mask.dtype, device=aatype.device, requires_grad=False
+        3, dtype=all_atom_mask.dtype, device=aatype.device
     )
     restype_rigidgroup_rots = torch.tile(
         restype_rigidgroup_rots,
@@ -476,7 +473,7 @@ def build_template_angle_feat(angle_feats, template_aatype):
     return template_angle_feat

-def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8):
+def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e8):
     template_mask = batch["template_pseudo_beta_mask"]
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
@@ -507,20 +504,30 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e8):
     )

     n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']]

+    affines = T.make_transform_from_reference(
+        n_xyz=batch["template_all_atom_positions"][..., n, :],
+        ca_xyz=batch["template_all_atom_positions"][..., ca, :],
+        c_xyz=batch["template_all_atom_positions"][..., c, :],
+    )
+    points = affines.get_trans()[..., None, :, :]
+    affine_vec = affines[..., None].invert_apply(points)
+
+    inv_distance_scalar = torch.rsqrt(
+        eps + torch.sum(affine_vec ** 2, dim=-1)
+    )
+
     t_aa_masks = batch["template_all_atom_masks"]
     template_mask = (
         t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
     )
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

-    unit_vector = template_mask_2d.new_zeros(*template_mask_2d.shape, 3)
-    to_concat.append(unit_vector)
+    inv_distance_scalar = inv_distance_scalar * template_mask_2d
+    unit_vector = (affine_vec * inv_distance_scalar[..., None])
+    to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))

     to_concat.append(template_mask_2d[..., None])

     act = torch.cat(to_concat, dim=-1)
-    act *= template_mask_2d[..., None]
+    act = act * template_mask_2d[..., None]

     return act
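Note: the all-zero unit_vector placeholder is replaced by the real template feature: inter-residue displacement vectors expressed in each residue's backbone (N/CA/C) frame, scaled by rsqrt(eps + |v|^2) so they are approximately unit length, then masked. A toy numeric sketch of just the normalization step:

    import torch

    eps = 1e-20
    affine_vec = torch.tensor([[3.0, 0.0, 4.0], [0.0, 0.0, 0.0]])
    inv_distance_scalar = torch.rsqrt(eps + torch.sum(affine_vec ** 2, dim=-1))
    unit_vector = affine_vec * inv_distance_scalar[..., None]
    print(unit_vector)   # [[0.6, 0.0, 0.8], [0.0, 0.0, 0.0]]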
@@ -594,7 +601,7 @@ def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> None:
     """
     ambiguous_atoms = (
         batch["atom14_gt_positions"].new_tensor(
-            rc.restype_atom14_ambiguous_atoms, requires_grad=False,
+            rc.restype_atom14_ambiguous_atoms
         )
     )
@@ -603,9 +610,7 @@ def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> None:
     # Swap pairs of ambiguous positions
     swap_idx = rc.restype_atom14_ambiguous_atoms_swap_idx
     swap_mat = np.eye(swap_idx.shape[-1])[swap_idx]  # one-hot swap_idx
-    swap_mat = batch["atom14_gt_positions"].new_tensor(
-        swap_mat, requires_grad=False
-    )
+    swap_mat = batch["atom14_gt_positions"].new_tensor(swap_mat)
     swap_mat = swap_mat[batch["aatype"], ...]

     atom14_alt_gt_positions = (
         torch.sum(
...
@@ -97,8 +97,8 @@ def compute_fape(
     error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)

     normed_error = error_dist / length_scale
-    normed_error *= frames_mask[..., None]
-    normed_error *= positions_mask[..., None, :]
+    normed_error = normed_error * frames_mask[..., None]
+    normed_error = normed_error * positions_mask[..., None, :]

     # FP16-friendly averaging. Roughly equivalent to:
     #
@@ -291,7 +291,7 @@ def supervised_chi_loss(
     )

     loss = 0
-    loss += chi_weight * sq_chi_loss
+    loss = loss + chi_weight * sq_chi_loss

     angle_norm = torch.sqrt(
         torch.sum(unnormalized_angles_sin_cos**2, dim=-1) + eps
@@ -304,7 +304,7 @@ def supervised_chi_loss(
         seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
     )

-    loss += angle_norm_weight * angle_norm_loss
+    loss = loss + angle_norm_weight * angle_norm_loss

     return loss
@@ -380,7 +380,7 @@ def lddt_loss(
         (dist_l1 < 2.0).type(dist_l1.dtype) +
         (dist_l1 < 4.0).type(dist_l1.dtype)
     )
-    score *= 0.25
+    score = score * 0.25

     norm = 1. / (eps + torch.sum(dists_to_score, dim=-1))
     score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))
@@ -400,7 +400,7 @@ def lddt_loss(
         (eps + torch.sum(all_atom_mask, dim=-1))
     )

-    loss *= (
+    loss = loss * (
         (resolution >= min_resolution) &
         (resolution <= max_resolution)
     )
@@ -452,50 +452,60 @@ def distogram_loss(
     return mean

-def tm_score(
+def tm_loss(
     logits,
-    t_pred,
-    t_gt,
-    mask,
+    final_affine_tensor,
+    backbone_affine_tensor,
+    backbone_affine_mask,
     resolution,
     max_bin=31,
     no_bins=64,
     min_resolution: float = 0.1,
     max_resolution: float = 3.0,
-    eps=1e-8
+    eps=1e-8,
+    **kwargs,
 ):
-    boundaries = torch.linspace(
-        min=0,
-        max=max_bin,
-        steps=(no_bins - 1),
-        device=logits.device
-    )
-    boundaries = boundaries ** 2
+    pred_affine = T.from_4x4(final_affine_tensor)
+    backbone_affine = T.from_4x4(backbone_affine_tensor)

     def _points(affine):
-        pts = affine.trans.unsqueeze(-3)
-        return affine.invert().apply(pts, addl_dims=1)
+        pts = affine.get_trans()[..., None, :, :]
+        return affine.invert()[..., None].apply(pts)

-    sq_diff = torch.sum((_points(t_pred) - _points(t_gt)) ** 2, dim=-1)
+    sq_diff = torch.sum(
+        (_points(pred_affine) - _points(backbone_affine)) ** 2,
+        dim=-1
+    )

     sq_diff = sq_diff.detach()

+    boundaries = torch.linspace(
+        0,
+        max_bin,
+        steps=(no_bins - 1),
+        device=logits.device
+    )
+    boundaries = boundaries ** 2
+
     true_bins = torch.sum(
-        sq_diff[..., None] > boundaries
-    ).float()
+        sq_diff[..., None] > boundaries, dim=-1
+    )

     errors = softmax_cross_entropy(
         logits,
         torch.nn.functional.one_hot(true_bins, no_bins)
     )

-    square_mask = mask[..., None] * mask[..., None, :]
-
-    loss = (
-        torch.sum(loss, dim=(-1, -2)) /
-        (eps + torch.sum(square_mask, dim=(-1, -2)))
-    )
+    square_mask = (
+        backbone_affine_mask[..., None] * backbone_affine_mask[..., None, :]
+    )
+
+    loss = torch.sum(errors * square_mask, dim=-1)
+    scale = 0.1  # hack to help FP16 training along
+    denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
+    loss = loss / denom[..., None]
+    loss = torch.sum(loss, dim=-1)
+    loss = loss / scale

-    loss *= (
+    loss = loss * (
         (resolution >= min_resolution) &
         (resolution <= max_resolution)
     )
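Note: in the rewritten tm_loss the squared frame-to-frame deviations are detached and bucketized against squared distance boundaries; true_bins is the per-pair bin index fed to softmax_cross_entropy. A toy illustration of the binning with the default max_bin/no_bins:

    import torch

    no_bins, max_bin = 64, 31
    boundaries = torch.linspace(0, max_bin, steps=no_bins - 1) ** 2
    sq_diff = torch.tensor([0.0, 2.25, 10000.0])   # squared deviations in Å^2
    true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1)
    print(true_bins)   # tensor([ 0,  3, 63])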
@@ -729,7 +739,7 @@ def between_residue_clash_loss(
     # Mask out all the duplicate entries in the lower triangular matrix.
     # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
     # are handled separately.
-    dists_mask *= (
+    dists_mask = dists_mask * (
         residue_index[..., :, None, None, None] < residue_index[..., None, :, None, None]
     )
@@ -758,7 +768,7 @@ def between_residue_clash_loss(
         c_one_hot[..., None, None, :, None] *
         n_one_hot[..., None, None, None, :]
     )
-    dists_mask *= (1. - c_n_bonds)
+    dists_mask = dists_mask * (1. - c_n_bonds)

     # Disulfide bridge between two cysteines is no clash.
     cys = residue_constants.restype_name_to_atom14_names["CYS"]
@@ -773,7 +783,7 @@ def between_residue_clash_loss(
     disulfide_bonds = (
         cys_sg_one_hot[..., None, None, :, None] *
         cys_sg_one_hot[..., None, None, None, :])
-    dists_mask *= (1. - disulfide_bonds)
+    dists_mask = dists_mask * (1. - disulfide_bonds)

     # Compute the lower bound for the allowed distances.
     # shape (N, N, 14, 14)
@@ -1038,7 +1048,7 @@ def find_structural_violations_np(
     atom14_pred_positions: np.ndarray,
     config: ml_collections.ConfigDict
 ) -> Dict[str, np.ndarray]:
-    to_tensor = lambda x: torch.tensor(x, requires_grad=False)
+    to_tensor = lambda x: torch.tensor(x)
     batch = tree_map(to_tensor, batch, np.ndarray)
     atom14_pred_positions = to_tensor(atom14_pred_positions)
@@ -1135,7 +1145,7 @@ def compute_violation_metrics_np(
     atom14_pred_positions: np.ndarray,
     violations: Dict[str, np.ndarray],
 ) -> Dict[str, np.ndarray]:
-    to_tensor = lambda x: torch.tensor(x, requires_grad=False)
+    to_tensor = lambda x: torch.tensor(x)
     batch = tree_map(to_tensor, batch, np.ndarray)
     atom14_pred_positions = to_tensor(atom14_pred_positions)
     violations = tree_map(to_tensor, violations, np.ndarray)
@@ -1285,10 +1295,11 @@ def experimentally_resolved_loss(
     **kwargs,
 ) -> torch.Tensor:
     errors = sigmoid_cross_entropy(logits, all_atom_mask)
-    loss_num = torch.sum(errors * atom37_atom_exists, dim=(-1, -2))
-    loss = loss_num / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
+    loss = torch.sum(errors * atom37_atom_exists, dim=-1)
+    loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
+    loss = torch.sum(loss, dim=-1)

-    loss *= (
+    loss = loss * (
         (resolution >= min_resolution) &
         (resolution <= max_resolution)
     )
@@ -1307,11 +1318,13 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
     #     torch.sum(errors * bert_mask, dim=(-1, -2)) /
     #     (eps + torch.sum(bert_mask, dim=(-1, -2)))
     # )
-    denom = eps + torch.sum(bert_mask, dim=(-1, -2))
     loss = errors * bert_mask
     loss = torch.sum(loss, dim=-1)
+    scale = 0.1
+    denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2))
     loss = loss / denom[..., None]
     loss = torch.sum(loss, dim=-1)
+    loss = loss / scale

     return loss
@@ -1403,6 +1416,11 @@ class AlphaFoldLoss(nn.Module):
                     out["violation"],
                     **batch,
                 ),
+            "tm":
+                lambda: tm_loss(
+                    logits=out["tm_logits"],
+                    **{**batch, **out, **self.config.tm},
+                ),
         }

         cum_loss = 0
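Note: the new "tm" entry follows the registry convention used by the other losses: tm_loss accepts **kwargs, so the union of the batch, the model outputs, and config.loss.tm can be splatted in wholesale and unused keys are ignored. Illustrative sketch of that pattern (toy function, not OpenFold code):

    def toy_loss(logits, resolution, weight=1.0, **kwargs):
        return weight * (logits - resolution)

    batch = {"resolution": 2.0, "aatype": "unused"}
    out = {"logits": 5.0, "pair": "unused"}
    cfg = {"weight": 0.5}
    print(toy_loss(**{**batch, **out, **cfg}))   # 1.5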
...
@@ -57,19 +57,6 @@ def dict_multimap(fn, dicts):
     return new_dict

-def stack_tensor_dicts(dicts):
-    first = dicts[0]
-    new_dict = {}
-    for k, v in first.items():
-        all_v = [d[k] for d in dicts]
-        if(type(v) is dict):
-            new_dict[k] = stack_tensor_dicts(all_v)
-        else:
-            new_dict[k] = torch.stack(all_v)
-
-    return new_dict
-
 def one_hot(x, v_bins):
     reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
     diffs = x[..., None] - reshaped_bins
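Note: stack_tensor_dicts is dropped in favour of the pre-existing dict_multimap helper; dict_multimap(torch.stack, outputs) stacks the per-block tensors under each key and recurses into nested dicts, which is what the structure module now calls. Self-contained sketch of the behaviour (simplified reimplementation, not the library code):

    import torch

    def dict_multimap(fn, dicts):
        first = dicts[0]
        out = {}
        for k, v in first.items():
            all_v = [d[k] for d in dicts]
            out[k] = dict_multimap(fn, all_v) if isinstance(v, dict) else fn(all_v)
        return out

    outputs = [{"frames": torch.zeros(5, 7), "sc": {"angles": torch.zeros(5, 7, 2)}}
               for _ in range(8)]
    stacked = dict_multimap(torch.stack, outputs)
    print(stacked["frames"].shape, stacked["sc"]["angles"].shape)
    # torch.Size([8, 5, 7]) torch.Size([8, 5, 7, 2])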
@@ -119,6 +106,7 @@ def tree_map(fn, tree, leaf_type):
 tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)

 def chunk_layer(
     layer: Callable,
     inputs: Dict[str, Any],
...