"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "b8120504735bcc6f981fdddcce044026bd155e0e"
Commit eb49136d authored by Gustaf Ahdritz

Finish purge of in-place ops, get grads working, add TM

parent 1d47c1e7
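
Note: the recurring pattern in this diff is replacing in-place tensor mutation (`a += b`, `a *= b`) with out-of-place assignment (`a = a + b`) so that autograd can track every operation, which matters once gradients are enabled through the recycling loop. A minimal, hypothetical sketch of why (not OpenFold code; the tensor names are made up):

```python
import torch

a = torch.randn(4, requires_grad=True)
bias = torch.randn(4)

# In-place: `a += bias` would raise a RuntimeError here, because `a` is a
# leaf tensor that requires grad and may not be modified in place.
# Out-of-place: creates a new node in the autograd graph instead.
a_out = a + bias

loss = a_out.sum()
loss.backward()   # gradients flow back to `a`
print(a.grad)     # tensor([1., 1., 1., 1.])
```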
......@@ -23,8 +23,8 @@ blocks_per_ckpt = mlc.FieldReference(1, field_type=int)
chunk_size = mlc.FieldReference(None, field_type=int)
aux_distogram_bins = mlc.FieldReference(64)
eps = 1e-4
inf = 1e4
eps = 1e-8
inf = 1e8
config = mlc.ConfigDict({
"model": {
......@@ -33,7 +33,7 @@ config = mlc.ConfigDict({
"c_t": c_t,
"c_e": c_e,
"c_s": c_s,
"no_cycles": 4,
"no_cycles": 2,#4,
"_mask_trans": False,
"input_embedder": {
"tf_dim": 22,
......@@ -117,7 +117,7 @@ config = mlc.ConfigDict({
"inf": inf,#1e9,
"eps": eps,#1e-10,
},
"enabled": False,#True,
"enabled": True,
},
"evoformer_stack": {
"c_m": c_m,
......@@ -147,7 +147,7 @@ config = mlc.ConfigDict({
"no_qk_points": 4,
"no_v_points": 8,
"dropout_rate": 0.1,
"no_blocks": 8,
"no_blocks": 2,#8,
"no_transition_layers": 1,
"no_resnet_blocks": 2,
"no_angles": 7,
......@@ -165,10 +165,10 @@ config = mlc.ConfigDict({
"c_z": c_z,
"no_bins": aux_distogram_bins,
},
"tm_score": {
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"enabled": False,
"enabled": True,
},
"masked_msa": {
"c_m": c_m,
......@@ -239,6 +239,14 @@ config = mlc.ConfigDict({
"eps": eps,#1e-6,
"weight": 0.,
},
"tm": {
"max_bin": 31,
"no_bins": 64,
"min_resolution": 0.1,
"max_resolution": 3.0,
"eps": eps,#1e-8,
"weight": 1.0,
},
"eps": eps,
},
})
......@@ -53,7 +53,7 @@ class Dropout(nn.Module):
if(self.batch_dim is not None):
for bd in self.batch_dim:
shape[bd] = 1
mask = x.new_ones(shape, requires_grad=False)
mask = x.new_ones(shape)
mask = self.dropout(mask)
x = x * mask
return x
......
......@@ -194,7 +194,6 @@ class RecyclingEmbedder(nn.Module):
self.max_bin,
self.no_bins,
dtype=x.dtype,
requires_grad=False,
device=x.device
)
......
......@@ -40,9 +40,9 @@ class AuxiliaryHeads(nn.Module):
**config["experimentally_resolved"],
)
if(config.tm_score.enabled):
self.tm_score = TMScoreHead(
**config["tm_score"],
if(config.tm.enabled):
self.tm = TMScoreHead(
**config.tm,
)
self.config = config
......@@ -68,9 +68,9 @@ class AuxiliaryHeads(nn.Module):
experimentally_resolved_logits
)
if(self.config.tm_score.enabled):
tm_score_logits = self.tm_score(outputs["pair"])
aux_out["tm_score_logits"] = tm_score_logits
if(self.config.tm.enabled):
tm_logits = self.tm(outputs["pair"])
aux_out["tm_logits"] = tm_logits
return aux_out
......
......@@ -115,10 +115,6 @@ class AlphaFold(nn.Module):
batch,
)
#tensor_dtype = (
# single_template_feats["template_all_atom_masks"].dtype
#)
# Build template angle feats
angle_feats = atom37_to_torsion_angles(
single_template_feats["template_aatype"],
......@@ -127,10 +123,6 @@ class AlphaFold(nn.Module):
eps=self.config.template.eps,
)
#angle_feats = tensor_tree_map(
# lambda t: t.type(tensor_dtype), angle_feats
#)
template_angle_feat = build_template_angle_feat(
angle_feats,
single_template_feats["template_aatype"],
......@@ -211,19 +203,16 @@ class AlphaFold(nn.Module):
# [*, N, C_m]
m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.c_m),
requires_grad=False,
)
# [*, N, N, C_z]
z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.c_z),
requires_grad=False,
)
# [*, N, 3]
x_prev = z.new_zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3),
requires_grad=False,
)
x_prev = pseudo_beta_fn(
......@@ -241,7 +230,7 @@ class AlphaFold(nn.Module):
)
# [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev_emb
m[..., 0, :, :] = m[..., 0, :, :] + m_1_prev_emb
# [*, N, N, C_z]
z = z + z_prev_emb
......@@ -312,6 +301,7 @@ class AlphaFold(nn.Module):
outputs["sm"]["positions"][-1], feats
)
outputs["final_atom_mask"] = feats["atom37_atom_exists"]
outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
# Save embeddings for use during the next recycling iteration
......@@ -342,6 +332,16 @@ class AlphaFold(nn.Module):
self.config.extra_msa.extra_msa_stack.blocks_per_ckpt
)
def _disable_grad(self):
vals = [p.requires_grad for p in self.parameters()]
for p in self.parameters():
p.requires_grad_(False)
return vals
def _enable_grad(self, vals):
for p, v in zip(self.parameters(), vals):
p.requires_grad_(v)
def forward(self, batch):
"""
Args:
......@@ -391,12 +391,13 @@ class AlphaFold(nn.Module):
for which C_alpha is used instead)
"template_pseudo_beta_mask" ([*, N_templ, N_res])
Pseudo-beta mask
"""
"""
# Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None
# Disable activation checkpointing until the final recycling layer
self._disable_activation_checkpointing()
grad_vals = self._disable_grad()
# Main recycling loop
for cycle_no in range(self.config.no_cycles):
......@@ -405,14 +406,17 @@ class AlphaFold(nn.Module):
feats = tensor_tree_map(fetch_cur_batch, batch)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = (cycle_no == self.config.no_cycles - 1)
if(self.training and is_final_iter):
is_final_iter = (cycle_no == (self.config.no_cycles - 1))
if(is_final_iter):
self._enable_activation_checkpointing()
with torch.set_grad_enabled(self.training and is_final_iter):
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats, m_1_prev, z_prev, x_prev,
)
self._enable_grad(grad_vals)
# Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats, m_1_prev, z_prev, x_prev,
)
# Run auxiliary heads
outputs.update(self.aux_heads(outputs))
return outputs
......@@ -94,10 +94,8 @@ class MSAAttention(nn.Module):
n_seq, n_res = m.shape[-3:-1]
if(mask is None):
# [*, N_seq, N_res]
mask = torch.ones(
mask = m.new_ones(
m.shape[:-3] + (n_seq, n_res),
device=m.device,
requires_grad=False
)
# [*, N_seq, 1, 1, N_res]
......
......@@ -70,7 +70,7 @@ class OuterProductMean(nn.Module):
[*, N_res, N_res, C_z] pair embedding update
"""
if(mask is None):
mask = m.new_ones(m.shape[:-1], requires_grad=False)
mask = m.new_ones(m.shape[:-1])
# [*, N_seq, N_res, C_m]
m = self.layer_norm(m)
......
......@@ -64,7 +64,7 @@ class PairTransition(nn.Module):
"""
# DISCREPANCY: DeepMind forgets to apply the mask in this module.
if(mask is None):
mask = z.new_ones(z.shape[:-1], requires_grad=False)
mask = z.new_ones(z.shape[:-1])
# [*, N_res, N_res, 1]
mask = mask.unsqueeze(-1)
......
......@@ -251,10 +251,10 @@ class Attention(nn.Module):
permute_final_dims(k, (0, 2, 3, 1)), # [*, H, C_hidden, K]
)
norm = 1 / math.sqrt(self.c_hidden) # [1]
a *= norm
a = a * norm
if(biases is not None):
for b in biases:
a += b
a = a + b
a = self.softmax(a)
#print(torch.any(torch.isnan(a)))
......@@ -330,7 +330,7 @@ class GlobalAttention(nn.Module):
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a += bias
a = a + bias
a = self.softmax(a)
# [*, N_res, H, C_hidden]
......
......@@ -27,7 +27,7 @@ from openfold.np.residue_constants import (
)
from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.tensor_utils import (
stack_tensor_dicts,
dict_multimap,
permute_final_dims,
flatten_final_dims,
)
......@@ -337,10 +337,15 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
# [*, H, 3, N_res, P_v]
o_pt = torch.matmul(
a.unsqueeze(-3), # [*, H, 1, N_res, N_res]
permute_final_dims(v_pts, (1, 3, 0, 2)), # [*, H, 3, N_res, P_v]
o_pt = torch.sum(
(
a[..., None, :, :, None] *
permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2
)
# [*, N_res, H, P_v, 3]
......@@ -702,7 +707,7 @@ class StructureModule(nn.Module):
"""
if(mask is None):
# [*, N]
mask = s.new_ones(s.shape[:-1], requires_grad=False)
mask = s.new_ones(s.shape[:-1])
# [*, N, C_s]
s = self.layer_norm_s(s)
......@@ -718,7 +723,7 @@ class StructureModule(nn.Module):
t = T.identity(s.shape[:-1], s.dtype, s.device, self.training)
outputs = []
for l in range(self.no_blocks):
for i in range(self.no_blocks):
# [*, N, C_s]
s = s + self.ipa(s, z, t, mask)
s = self.ipa_dropout(s)
......@@ -751,10 +756,10 @@ class StructureModule(nn.Module):
outputs.append(preds)
if(l < self.no_blocks - 1):
if(i < (self.no_blocks - 1)):
t = t.stop_rot_gradient()
outputs = stack_tensor_dicts(outputs)
outputs = dict_multimap(torch.stack, outputs)
outputs["single"] = s
return outputs
......@@ -765,27 +770,23 @@ class StructureModule(nn.Module):
restype_rigid_group_default_frame,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if(self.group_idx is None):
self.group_idx = torch.tensor(
restype_atom14_to_rigid_group,
device=device,
requires_grad=False,
)
if(self.atom_mask is None):
self.atom_mask = torch.tensor(
restype_atom14_mask,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if(self.lit_positions is None):
self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions,
dtype=float_dtype,
device=device,
requires_grad=False,
)
def torsion_angles_to_frames(self, t, alpha, f):
......@@ -799,8 +800,6 @@ class StructureModule(nn.Module):
f # [*, N]
):
# Lazily initialize the residue constants on the correct device
# TODO: Maybe this stuff should be done on CPU instead (so these
# arrays
self._init_residue_constants(t.rots.dtype, t.rots.device)
return _frames_and_literature_positions_to_atom14_pos(
t,
......
......@@ -73,10 +73,8 @@ class TriangleAttention(nn.Module):
"""
if(mask is None):
# [*, I, J]
mask = torch.ones(
mask = x.new_ones(
x.shape[:-1],
device=x.device,
requires_grad=False,
)
# Shape annotations assume self.starting. Else, I and J are flipped
......
......@@ -91,7 +91,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
[*, N_res, N_res, C_z] output tensor
"""
if(mask is None):
mask = z.new_ones(z.shape[:-1], requires_grad=False)
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
......
......@@ -163,7 +163,7 @@ class T:
return trans
@staticmethod
def identity(shape, dtype, device, requires_grad=False):
def identity(shape, dtype, device, requires_grad=True):
return T(
T.identity_rot(shape, dtype, device, requires_grad),
T.identity_trans(shape, dtype, device, requires_grad),
......@@ -191,11 +191,6 @@ class T:
e0 = origin - p_neg_x_axis
e1 = p_xy_plane - origin
# Angle norming is very sensitive to floating point imprecisions
#float_type = e0.dtype
#e0 = e0.float()
#e1 = e1.float()
e0 = e0 / torch.sqrt(torch.sum(e0 ** 2, dim=-1, keepdims=True) + eps)
e1 = e1 - e0 * torch.sum(e0 * e1, dim=-1, keepdims=True)
e1 = e1 / torch.sqrt(torch.sum(e1 ** 2, dim=-1, keepdims=True) + eps)
......@@ -203,8 +198,6 @@ class T:
rots = torch.stack([e0, e1, e2], dim=-1)
#rots = rots.type(float_type)
return T(rots, origin)
@staticmethod
......@@ -221,7 +214,8 @@ class T:
return T(rots, trans)
def map_tensor_fn(self, fn):
""" Apply a function that takes a tensor as its only argument to the
"""
Apply a function that takes a tensor as its only argument to the
rotations and translations, treating the final two/one
dimension(s), respectively, as batch dimensions.
......@@ -253,7 +247,7 @@ class T:
n_xyz = n_xyz + translation
c_xyz = c_xyz + translation
c_x, c_y, c_z = [c_xyz[...,i] for i in range(3)]
c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
norm = torch.sqrt(eps + c_x**2 + c_y**2)
sin_c1 = -c_y / norm
cos_c1 = c_x / norm
......@@ -278,7 +272,7 @@ class T:
c1_rots[..., 2, 0] = -1 * sin_c2
c1_rots[..., 2, 2] = cos_c2
c_rots = rot_matmul(c2_rot_matrix, c1_rot_matrix)
c_rots = rot_matmul(c2_rots, c1_rots)
n_xyz = rot_vec_mul(c_rots, n_xyz)
_, n_y, n_z = [n_xyz[..., i] for i in range(3)]
......
......@@ -151,7 +151,7 @@ def atom14_to_atom37(atom14, batch):
no_batch_dims=len(atom14.shape[:-2]),
)
atom37_data *= batch["atom37_atom_exists"][..., None]
atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]
return atom37_data
......@@ -288,7 +288,7 @@ def atom37_to_torsion_angles(
)
torsion_angles_sin_cos = torsion_angles_sin_cos / denom
torsion_angles_sin_cos *= torch.tensor(
torsion_angles_sin_cos = torsion_angles_sin_cos * torch.tensor(
[1., 1., -1., 1., 1., 1., 1.], device=aatype.device,
)[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
......@@ -335,11 +335,8 @@ def atom37_to_frames(
restype, chi_idx + 4, :
] = names[1:]
restype_rigidgroup_mask = torch.zeros(
restype_rigidgroup_mask = all_atom_mask.new_zeros(
(*aatype.shape[:-1], 21, 8),
dtype=all_atom_mask.dtype,
device=aatype.device,
requires_grad=False
)
restype_rigidgroup_mask[..., 0] = 1
restype_rigidgroup_mask[..., 3] = 1
......@@ -399,7 +396,7 @@ def atom37_to_frames(
gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
rots = torch.eye(
3, dtype=all_atom_mask.dtype, device=aatype.device, requires_grad=False
3, dtype=all_atom_mask.dtype, device=aatype.device
)
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1
......@@ -411,7 +408,7 @@ def atom37_to_frames(
*((1,) * batch_dims), 21, 8
)
restype_rigidgroup_rots = torch.eye(
3, dtype=all_atom_mask.dtype, device=aatype.device, requires_grad=False
3, dtype=all_atom_mask.dtype, device=aatype.device
)
restype_rigidgroup_rots = torch.tile(
restype_rigidgroup_rots,
......@@ -476,7 +473,7 @@ def build_template_angle_feat(angle_feats, template_aatype):
return template_angle_feat
def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8):
def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e8):
template_mask = batch["template_pseudo_beta_mask"]
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
......@@ -507,20 +504,30 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
)
n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']]
affines = T.make_transform_from_reference(
n_xyz=batch["template_all_atom_positions"][..., n, :],
ca_xyz=batch["template_all_atom_positions"][..., ca, :],
c_xyz=batch["template_all_atom_positions"][..., c, :],
)
points = affines.get_trans()[..., None, :, :]
affine_vec = affines[..., None].invert_apply(points)
inv_distance_scalar = torch.rsqrt(
eps + torch.sum(affine_vec ** 2, dim=-1)
)
t_aa_masks = batch["template_all_atom_masks"]
template_mask = (
t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
)
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
unit_vector = template_mask_2d.new_zeros(*template_mask_2d.shape, 3)
to_concat.append(unit_vector)
inv_distance_scalar = inv_distance_scalar * template_mask_2d
unit_vector = (affine_vec * inv_distance_scalar[..., None])
to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
to_concat.append(template_mask_2d[..., None])
act = torch.cat(to_concat, dim=-1)
act *= template_mask_2d[..., None]
act = act * template_mask_2d[..., None]
return act
......@@ -594,7 +601,7 @@ def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> None:
"""
ambiguous_atoms = (
batch["atom14_gt_positions"].new_tensor(
rc.restype_atom14_ambiguous_atoms, requires_grad=False,
rc.restype_atom14_ambiguous_atoms
)
)
......@@ -603,9 +610,7 @@ def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> None:
# Swap pairs of ambiguous positions
swap_idx = rc.restype_atom14_ambiguous_atoms_swap_idx
swap_mat = np.eye(swap_idx.shape[-1])[swap_idx] # one-hot swap_idx
swap_mat = batch["atom14_gt_positions"].new_tensor(
swap_mat, requires_grad=False
)
swap_mat = batch["atom14_gt_positions"].new_tensor(swap_mat)
swap_mat = swap_mat[batch["aatype"], ...]
atom14_alt_gt_positions = (
torch.sum(
......
......@@ -97,8 +97,8 @@ def compute_fape(
error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)
normed_error = error_dist / length_scale
normed_error *= frames_mask[..., None]
normed_error *= positions_mask[..., None, :]
normed_error = normed_error * frames_mask[..., None]
normed_error = normed_error * positions_mask[..., None, :]
# FP16-friendly averaging. Roughly equivalent to:
#
......@@ -291,7 +291,7 @@ def supervised_chi_loss(
)
loss = 0
loss += chi_weight * sq_chi_loss
loss = loss + chi_weight * sq_chi_loss
angle_norm = torch.sqrt(
torch.sum(unnormalized_angles_sin_cos**2, dim=-1) + eps
......@@ -304,7 +304,7 @@ def supervised_chi_loss(
seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
)
loss += angle_norm_weight * angle_norm_loss
loss = loss + angle_norm_weight * angle_norm_loss
return loss
......@@ -380,7 +380,7 @@ def lddt_loss(
(dist_l1 < 2.0).type(dist_l1.dtype) +
(dist_l1 < 4.0).type(dist_l1.dtype)
)
score *= 0.25
score = score * 0.25
norm = 1. / (eps + torch.sum(dists_to_score, dim=-1))
score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))
......@@ -400,7 +400,7 @@ def lddt_loss(
(eps + torch.sum(all_atom_mask, dim=-1))
)
loss *= (
loss = loss * (
(resolution >= min_resolution) &
(resolution <= max_resolution)
)
......@@ -452,50 +452,60 @@ def distogram_loss(
return mean
def tm_score(
def tm_loss(
logits,
t_pred,
t_gt,
mask,
final_affine_tensor,
backbone_affine_tensor,
backbone_affine_mask,
resolution,
max_bin=31,
no_bins=64,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
eps=1e-8
eps=1e-8,
**kwargs,
):
boundaries = torch.linspace(
min=0,
max=max_bin,
steps=(no_bins - 1),
device=logits.device
)
boundaries = boundaries ** 2
pred_affine = T.from_4x4(final_affine_tensor)
backbone_affine = T.from_4x4(backbone_affine_tensor)
def _points(affine):
pts = affine.trans.unsqueeze(-3)
return affine.invert().apply(pts, addl_dims=1)
pts = affine.get_trans()[..., None, :, :]
return affine.invert()[..., None].apply(pts)
sq_diff = torch.sum((_points(t_pred) - _points(t_gt)) ** 2, dim=-1)
sq_diff = torch.sum(
(_points(pred_affine) - _points(backbone_affine)) ** 2,
dim=-1
)
sq_diff = sq_diff.detach()
boundaries = torch.linspace(
0,
max_bin,
steps=(no_bins - 1),
device=logits.device
)
boundaries = boundaries ** 2
true_bins = torch.sum(
sq_diff[..., None] > boundaries
).float()
sq_diff[..., None] > boundaries, dim=-1
)
errors = softmax_cross_entropy(
logits,
torch.nn.functional.one_hot(true_bins, no_bins)
)
square_mask = mask[..., None] * mask[..., None, :]
loss = (
torch.sum(loss, dim=(-1, -2)) /
(eps + torch.sum(square_mask, dim=(-1, -2)))
square_mask = (
backbone_affine_mask[..., None] * backbone_affine_mask[..., None, :]
)
loss *= (
loss = torch.sum(errors * square_mask, dim=-1)
scale = 0.1 # hack to help FP16 training along
denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
loss = loss / denom[..., None]
loss = torch.sum(loss, dim=-1)
loss = loss / scale
loss = loss * (
(resolution >= min_resolution) &
(resolution <= max_resolution)
)
......@@ -729,7 +739,7 @@ def between_residue_clash_loss(
# Mask out all the duplicate entries in the lower triangular matrix.
# Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
# are handled separately.
dists_mask *= (
dists_mask = dists_mask * (
residue_index[..., :, None, None, None] < residue_index[..., None, :, None, None]
)
......@@ -758,7 +768,7 @@ def between_residue_clash_loss(
c_one_hot[..., None, None, :, None] *
n_one_hot[..., None, None, None, :]
)
dists_mask *= (1. - c_n_bonds)
dists_mask = dists_mask * (1. - c_n_bonds)
# Disulfide bridge between two cysteines is no clash.
cys = residue_constants.restype_name_to_atom14_names["CYS"]
......@@ -773,7 +783,7 @@ def between_residue_clash_loss(
disulfide_bonds = (
cys_sg_one_hot[..., None, None, :, None] *
cys_sg_one_hot[..., None, None, None, :])
dists_mask *= (1. - disulfide_bonds)
dists_mask = dists_mask * (1. - disulfide_bonds)
# Compute the lower bound for the allowed distances.
# shape (N, N, 14, 14)
......@@ -1038,7 +1048,7 @@ def find_structural_violations_np(
atom14_pred_positions: np.ndarray,
config: ml_collections.ConfigDict
) -> Dict[str, np.ndarray]:
to_tensor = lambda x: torch.tensor(x, requires_grad=False)
to_tensor = lambda x: torch.tensor(x)
batch = tree_map(to_tensor, batch, np.ndarray)
atom14_pred_positions = to_tensor(atom14_pred_positions)
......@@ -1135,7 +1145,7 @@ def compute_violation_metrics_np(
atom14_pred_positions: np.ndarray,
violations: Dict[str, np.ndarray],
) -> Dict[str, np.ndarray]:
to_tensor = lambda x: torch.tensor(x, requires_grad=False)
to_tensor = lambda x: torch.tensor(x)
batch = tree_map(to_tensor, batch, np.ndarray)
atom14_pred_positions = to_tensor(atom14_pred_positions)
violations = tree_map(to_tensor, violations, np.ndarray)
......@@ -1285,10 +1295,11 @@ def experimentally_resolved_loss(
**kwargs,
) -> torch.Tensor:
errors = sigmoid_cross_entropy(logits, all_atom_mask)
loss_num = torch.sum(errors * atom37_atom_exists, dim=(-1, -2))
loss = loss_num / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
loss = torch.sum(errors * atom37_atom_exists, dim=-1)
loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
loss = torch.sum(loss, dim=-1)
loss *= (
loss = loss * (
(resolution >= min_resolution) &
(resolution <= max_resolution)
)
......@@ -1307,11 +1318,13 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
# torch.sum(errors * bert_mask, dim=(-1, -2)) /
# (eps + torch.sum(bert_mask, dim=(-1, -2)))
# )
denom = eps + torch.sum(bert_mask, dim=(-1, -2))
loss = errors * bert_mask
loss = torch.sum(loss, dim=-1)
scale = 0.1
denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2))
loss = loss / denom[..., None]
loss = torch.sum(loss, dim=-1)
loss = loss / scale
return loss
......@@ -1403,6 +1416,11 @@ class AlphaFoldLoss(nn.Module):
out["violation"],
**batch,
),
"tm":
lambda: tm_loss(
logits=out["tm_logits"],
**{**batch, **out, **self.config.tm},
),
}
cum_loss = 0
......
......@@ -57,19 +57,6 @@ def dict_multimap(fn, dicts):
return new_dict
def stack_tensor_dicts(dicts):
first = dicts[0]
new_dict = {}
for k, v in first.items():
all_v = [d[k] for d in dicts]
if(type(v) is dict):
new_dict[k] = stack_tensor_dicts(all_v)
else:
new_dict[k] = torch.stack(all_v)
return new_dict
def one_hot(x, v_bins):
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
diffs = x[..., None] - reshaped_bins
......@@ -119,6 +106,7 @@ def tree_map(fn, tree, leaf_type):
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
......
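
For reference, a minimal, self-contained sketch of the binning-plus-cross-entropy idea behind the new `tm_loss` (the affine/frame handling is omitted; `d_sq`, `logits`, and `mask` below are made-up placeholder tensors, not OpenFold's, and the helper name is hypothetical):

```python
import torch
import torch.nn.functional as F

def binned_error(logits, d_sq, mask, max_bin=31, no_bins=64, eps=1e-8):
    # Bin boundaries are squared distances: linspace(0, max_bin) ** 2.
    boundaries = torch.linspace(
        0, max_bin, steps=no_bins - 1, device=logits.device
    ) ** 2
    # Index of the bin each squared distance falls into (0 .. no_bins - 1).
    true_bins = torch.sum(d_sq[..., None] > boundaries, dim=-1)
    # Cross entropy between the predicted logits and the one-hot true bins.
    errors = -torch.sum(
        F.one_hot(true_bins, no_bins) * F.log_softmax(logits, dim=-1),
        dim=-1,
    )
    # Mask out pairs involving residues without valid backbone frames,
    # then average over the remaining pairs.
    pair_mask = mask[..., None] * mask[..., None, :]
    return torch.sum(errors * pair_mask, dim=(-1, -2)) / (
        eps + torch.sum(pair_mask, dim=(-1, -2))
    )

n = 8
logits = torch.randn(n, n, 64)
d_sq = torch.rand(n, n) * 31.0 ** 2
mask = torch.ones(n)
print(binned_error(logits, d_sq, mask))  # scalar loss for this example
```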