Commit 1d47c1e7 authored by Gustaf Ahdritz

Finish accommodating FP16. FAPE does not decrease

parent 3d9d977a
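
The constants changed throughout this commit follow one theme: the hard-coded `inf` values (1e5, 1e8, 1e9) and `eps` values (1e-6 down to 1e-12) are replaced with shared config references of 1e4 and 1e-4 so that they survive half precision. A standalone illustration of the float16 range issue (not part of the commit):

```python
import torch

# float16 overflows above ~65504 and flushes magnitudes below ~6e-8 to zero,
# so the old masking constants and eps terms break once tensors are half.
print(torch.tensor(1e9, dtype=torch.float16))    # inf     (old "inf" value)
print(torch.tensor(1e4, dtype=torch.float16))    # 10000.0 (new value fits)
print(torch.tensor(1e-10, dtype=torch.float16))  # 0.0     (old eps underflows)
print(torch.tensor(1e-4, dtype=torch.float16))   # ~1e-4   (new eps survives)
```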
......@@ -12,17 +12,20 @@ def model_config(name):
c.model.template.enabled = False
return c
c_z = mlc.FieldReference(128)
c_m = mlc.FieldReference(256)
c_t = mlc.FieldReference(64)
c_e = mlc.FieldReference(64)
c_s = mlc.FieldReference(384)
blocks_per_ckpt = mlc.FieldReference(1)
chunk_size = mlc.FieldReference(4)#1280)
blocks_per_ckpt = mlc.FieldReference(1, field_type=int)
chunk_size = mlc.FieldReference(None, field_type=int)
aux_distogram_bins = mlc.FieldReference(64)
eps = 1e-4
inf = 1e4
config = mlc.ConfigDict({
"model": {
"c_z": c_z,
......@@ -45,7 +48,7 @@ config = mlc.ConfigDict({
"min_bin": 3.25,
"max_bin": 20.75,
"no_bins": 15,
"inf": 1e8,
"inf": inf,#1e8,
},
"template": {
"distogram": {
......@@ -74,6 +77,7 @@ config = mlc.ConfigDict({
"dropout_rate": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"inf": inf,
},
"template_pointwise_attention": {
"c_t": c_t,
......@@ -83,8 +87,10 @@ config = mlc.ConfigDict({
"c_hidden": 16,
"no_heads": 4,
"chunk_size": chunk_size,
"inf": inf,#1e-9,
},
"eps": 1e-6,
"inf": inf,
"eps": eps,#1e-6,
"enabled": True,
"embed_angles": True,
},
......@@ -108,10 +114,10 @@ config = mlc.ConfigDict({
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"inf": 1e9,
"eps": 1e-10,
"inf": inf,#1e9,
"eps": eps,#1e-10,
},
"enabled": True,
"enabled": False,#True,
},
"evoformer_stack": {
"c_m": c_m,
......@@ -129,8 +135,8 @@ config = mlc.ConfigDict({
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"inf": 1e9,
"eps": 1e-10,
"inf": inf,#1e9,
"eps": eps,#1e-10,
},
"structure_module": {
"c_s": c_s,
......@@ -146,8 +152,8 @@ config = mlc.ConfigDict({
"no_resnet_blocks": 2,
"no_angles": 7,
"trans_scale_factor": 10,
"epsilon": 1e-12,
"inf": 1e5,
"epsilon": eps,#1e-12,
"inf": inf,#1e5,
},
"heads": {
"lddt": {
......@@ -186,11 +192,11 @@ config = mlc.ConfigDict({
"min_bin": 2.3125,
"max_bin": 21.6875,
"no_bins": 64,
"eps": 1e-6,
"weight": 0.3,
"eps": eps,#1e-6,
"weight": 0.,#0.3,
},
"experimentally_resolved": {
"eps": 1e-8,
"eps": eps,#1e-8,
"min_resolution": 0.1,
"max_resolution": 3.0,
"weight": 0.,
......@@ -206,6 +212,7 @@ config = mlc.ConfigDict({
"length_scale": 10.,
"weight": 0.5,
},
"eps": 1e-4,
"weight": 1.0,
},
"lddt": {
......@@ -213,24 +220,25 @@ config = mlc.ConfigDict({
"max_resolution": 3.0,
"cutoff": 15.,
"no_bins": 50,
"eps": 1e-10,
"weight": 0.01,
"eps": eps,#1e-10,
"weight": 0.,#0.01,
},
"masked_msa": {
"eps": 1e-8,
"weight": 2.0,
"eps": eps,#1e-8,
"weight": 0.,#2.0,
},
"supervised_chi": {
"chi_weight": 0.5,
"angle_norm_weight": 0.01,
"eps": 1e-6,
"weight": 1.0,
"eps": eps,#1e-6,
"weight": 0.,#1.0,
},
"violation": {
"violation_tolerance_factor": 12.0,
"clash_overlap_tolerance": 1.5,
"eps": 1e-6,
"eps": eps,#1e-6,
"weight": 0.,
},
"eps": eps,
},
})
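
For context on the config hunks above: `ml_collections.FieldReference` lets many `ConfigDict` entries share one underlying value, which is how the new `eps` and `inf` settings reach every submodule from a single definition. A minimal sketch with toy field names, assuming the usual ml_collections semantics (assigning to a field backed by a FieldReference updates the shared reference):

```python
import ml_collections as mlc

# Two config entries point at the same FieldReference.
eps = mlc.FieldReference(1e-4)
cfg = mlc.ConfigDict({
    "template": {"eps": eps},
    "loss": {"distogram": {"eps": eps}},
})
print(cfg.template.eps, cfg.loss.distogram.eps)  # 0.0001 0.0001

# Updating one entry updates the shared reference, so the change propagates.
cfg.template.eps = 1e-8
print(cfg.loss.distogram.eps)  # 1e-08
```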
......@@ -115,14 +115,22 @@ class AlphaFold(nn.Module):
batch,
)
#tensor_dtype = (
# single_template_feats["template_all_atom_masks"].dtype
#)
# Build template angle feats
angle_feats = atom37_to_torsion_angles(
single_template_feats["template_aatype"],
single_template_feats["template_all_atom_positions"],
single_template_feats["template_all_atom_masks"],
eps=1e-8
single_template_feats["template_all_atom_positions"],#.float(),
single_template_feats["template_all_atom_masks"],#.float(),
eps=self.config.template.eps,
)
#angle_feats = tensor_tree_map(
# lambda t: t.type(tensor_dtype), angle_feats
#)
template_angle_feat = build_template_angle_feat(
angle_feats,
single_template_feats["template_aatype"],
......@@ -134,6 +142,7 @@ class AlphaFold(nn.Module):
# [*, S_t, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
inf=self.config.template.inf,
eps=self.config.template.eps,
**self.config.template.distogram
)
......@@ -162,11 +171,11 @@ class AlphaFold(nn.Module):
template_mask=batch["template_mask"]
)
t = t * (torch.sum(batch["template_mask"]) > 0)
return {
"template_angle_embedding": a,
"template_angle_embedding": template_embeds["angle"],
"template_pair_embedding": t,
"torsion_angles_mask": angle_feats["torsion_angles_mask"],
"torsion_angles_mask": template_embeds["torsion_mask"],
}
def iteration(self, feats, m_1_prev, z_prev, x_prev):
......@@ -251,7 +260,7 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
if(self.config.template.embed_angles):
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
......
......@@ -107,7 +107,6 @@ class MSAAttention(nn.Module):
bias = bias.expand(
((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
)
biases = [bias]
if(self.pair_bias):
# [*, N_res, N_res, C_z]
......
......@@ -257,6 +257,8 @@ class Attention(nn.Module):
a += b
a = self.softmax(a)
#print(torch.any(torch.isnan(a)))
# [*, H, Q, C_hidden]
o = torch.matmul(
a,
......@@ -328,7 +330,7 @@ class GlobalAttention(nn.Module):
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a += bias
a += bias
a = self.softmax(a)
# [*, N_res, H, C_hidden]
......
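
The attention hunks above all rely on the same additive-mask trick, `bias = inf * (mask - 1)`: masked positions get a large negative logit and therefore near-zero weight after the softmax. A small standalone sketch (toy logits, fp32 for clarity) showing why the magnitude of `inf` is what has to fit in the float16 range:

```python
import torch

# Positions where mask == 0 receive a -1e4 bias before the softmax and end up
# with essentially zero attention weight; 1e4 still fits in float16, whereas
# the previous 1e9 would overflow to inf once activations are cast to half.
inf_value = 1e4
mask = torch.tensor([1., 1., 0.])
logits = torch.tensor([0.5, 0.2, 3.0])
bias = inf_value * (mask - 1)                  # [0, 0, -1e4]
weights = torch.softmax(logits + bias, dim=-1)
print(weights)                                 # last entry ~0
```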
......@@ -50,6 +50,7 @@ class TemplatePointwiseAttention(nn.Module):
c_hidden,
no_heads,
chunk_size,
inf,
**kwargs
):
"""
......@@ -68,6 +69,7 @@ class TemplatePointwiseAttention(nn.Module):
self.c_hidden = c_hidden
self.no_heads = no_heads
self.chunk_size = chunk_size
self.inf = inf
self.mha = Attention(
self.c_z, self.c_t, self.c_t,
......@@ -89,11 +91,11 @@ class TemplatePointwiseAttention(nn.Module):
"""
if(template_mask is None):
# NOTE: This is not the "template_mask" from the supplement, but a
# [*, N_templ] mask from the code. I'm pretty sure it's always just 1,
# but not sure enough to remove it. It's nice to have, I guess.
# [*, N_templ] mask from the code. I'm pretty sure it's always just
# 1, but not sure enough to remove it. It's nice to have, I guess.
template_mask = t.new_ones(t.shape[:-3])
bias = (1e9 * (template_mask[..., None, None, None, None, :] - 1))
bias = (self.inf * (template_mask[..., None, None, None, None, :] - 1))
# [*, N_res, N_res, 1, C_z]
z = z.unsqueeze(-2)
......@@ -133,6 +135,8 @@ class TemplatePairStackBlock(nn.Module):
pair_transition_n,
dropout_rate,
chunk_size,
inf,
**kwargs,
):
super(TemplatePairStackBlock, self).__init__()
......@@ -143,6 +147,7 @@ class TemplatePairStackBlock(nn.Module):
self.pair_transition_n = pair_transition_n
self.dropout_rate = dropout_rate
self.chunk_size = chunk_size
self.inf = inf
self.dropout_row = DropoutRowwise(self.dropout_rate)
self.dropout_col = DropoutColumnwise(self.dropout_rate)
......@@ -152,12 +157,14 @@ class TemplatePairStackBlock(nn.Module):
self.c_hidden_tri_att,
self.no_heads,
chunk_size=chunk_size,
inf=inf,
)
self.tri_att_end = TriangleAttentionEndingNode(
self.c_t,
self.c_hidden_tri_att,
self.no_heads,
chunk_size=chunk_size,
inf=inf,
)
self.tri_mul_out = TriangleMultiplicationOutgoing(
......@@ -200,6 +207,7 @@ class TemplatePairStack(nn.Module):
dropout_rate,
blocks_per_ckpt,
chunk_size,
inf=1e9,
**kwargs,
):
"""
......@@ -237,6 +245,7 @@ class TemplatePairStack(nn.Module):
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
chunk_size=chunk_size,
inf=inf,
)
self.blocks.append(block)
......
......@@ -57,11 +57,17 @@ class T:
raise ValueError("Only one of rots and trans can be None")
elif(self.rots is None):
self.rots = T.identity_rot(
self.trans.shape[:-1], self.trans.dtype, self.trans.device
self.trans.shape[:-1],
self.trans.dtype,
self.trans.device,
self.trans.requires_grad,
)
elif(self.trans is None):
self.trans = T.identity_trans(
self.rots.shape[:-2], self.rots.dtype, self.rots.device
self.rots.shape[:-2],
self.rots.dtype,
self.rots.device,
self.rots.requires_grad
)
if(self.rots.shape[-2:] != (3, 3) or
......@@ -137,7 +143,7 @@ class T:
return T(rots, trans)
@staticmethod
def identity_rot(shape, dtype, device, requires_grad=False):
def identity_rot(shape, dtype, device, requires_grad):
rots = torch.eye(
3, dtype=dtype, device=device, requires_grad=requires_grad
)
......@@ -147,7 +153,7 @@ class T:
return rots
@staticmethod
def identity_trans(shape, dtype, device, requires_grad=False):
def identity_trans(shape, dtype, device, requires_grad):
trans = torch.zeros(
(*shape, 3),
dtype=dtype,
......@@ -182,20 +188,22 @@ class T:
@staticmethod
def from_3_points(p_neg_x_axis, origin, p_xy_plane, eps=1e-8):
v1 = origin - p_neg_x_axis
v2 = p_xy_plane - origin
e1 = v1 / torch.sqrt(torch.sum(v1 ** 2, dim=-1) + eps)[..., None]
u2 = v2 - e1 * (torch.einsum('...i,...i->...', v2, e1)[..., None])
e2 = u2 / torch.sqrt(torch.sum(u2 ** 2, dim=-1) + eps)[..., None]
e3 = torch.cross(e1, e2, dim=-1)
rots = torch.cat(
(
e1.unsqueeze(-1),
e2.unsqueeze(-1),
e3.unsqueeze(-1),
), dim=-1,
)
e0 = origin - p_neg_x_axis
e1 = p_xy_plane - origin
# Angle norming is very sensitive to floating point imprecisions
#float_type = e0.dtype
#e0 = e0.float()
#e1 = e1.float()
e0 = e0 / torch.sqrt(torch.sum(e0 ** 2, dim=-1, keepdims=True) + eps)
e1 = e1 - e0 * torch.sum(e0 * e1, dim=-1, keepdims=True)
e1 = e1 / torch.sqrt(torch.sum(e1 ** 2, dim=-1, keepdims=True) + eps)
e2 = torch.cross(e0, e1)
rots = torch.stack([e0, e1, e2], dim=-1)
#rots = rots.type(float_type)
return T(rots, origin)
......
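
The rewritten `from_3_points` constructs the rotation by Gram-Schmidt: normalize the vector from `p_neg_x_axis` to `origin`, orthogonalize the vector toward `p_xy_plane` against it, and complete the frame with a cross product. A self-contained sketch (made-up points and a hypothetical helper, not the commit's `T` class) that checks the resulting columns are orthonormal:

```python
import torch

def gram_schmidt_frame(p_neg_x_axis, origin, p_xy_plane, eps=1e-8):
    # e0: unit vector from p_neg_x_axis toward origin
    e0 = origin - p_neg_x_axis
    e0 = e0 / torch.sqrt(torch.sum(e0 ** 2, dim=-1, keepdim=True) + eps)
    # e1: component of (p_xy_plane - origin) orthogonal to e0, normalized
    e1 = p_xy_plane - origin
    e1 = e1 - e0 * torch.sum(e0 * e1, dim=-1, keepdim=True)
    e1 = e1 / torch.sqrt(torch.sum(e1 ** 2, dim=-1, keepdim=True) + eps)
    # e2 completes the right-handed frame
    e2 = torch.cross(e0, e1, dim=-1)
    return torch.stack([e0, e1, e2], dim=-1)

rots = gram_schmidt_frame(
    torch.tensor([1.0, 0.0, 0.0]),
    torch.tensor([0.0, 0.0, 0.0]),
    torch.tensor([0.0, 1.0, 0.0]),
)
print(rots @ rots.transpose(-1, -2))  # ~identity: columns are orthonormal
```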
......@@ -70,6 +70,9 @@ def checkpoint_blocks(
for s in range(0, len(blocks), blocks_per_ckpt):
e = s + blocks_per_ckpt
#print(len(args))
#for a in args:
# print(a.requires_grad)
args = checkpoint(chunker(s, e), *args)
#args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
args = wrap(args)
......
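
For reference, `checkpoint_blocks` above walks the block list in groups of `blocks_per_ckpt` and wraps each group in PyTorch's activation checkpointing, trading recomputation during backward for lower activation memory. A toy sketch of that grouping (plain `Linear` blocks, not the Evoformer):

```python
import torch
from torch.utils.checkpoint import checkpoint

blocks = [torch.nn.Linear(8, 8) for _ in range(4)]
blocks_per_ckpt = 2

def run_group(start, end):
    # Returns a callable running blocks[start:end]; only the group boundaries
    # keep activations, the interior is recomputed in the backward pass.
    def fn(x):
        for block in blocks[start:end]:
            x = block(x)
        return x
    return fn

x = torch.randn(2, 8, requires_grad=True)
for s in range(0, len(blocks), blocks_per_ckpt):
    x = checkpoint(run_group(s, s + blocks_per_ckpt), x)
x.sum().backward()
print(blocks[0].weight.grad.shape)  # gradients flow through every group
```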
......@@ -286,7 +286,7 @@ def atom37_to_torsion_angles(
torch.square(torsion_angles_sin_cos), dim=-1, keepdims=True
) + eps
)
torsion_angles_sin_cos /= denom
torsion_angles_sin_cos = torsion_angles_sin_cos / denom
torsion_angles_sin_cos *= torch.tensor(
[1., 1., -1., 1., 1., 1., 1.], device=aatype.device,
......@@ -298,7 +298,7 @@ def atom37_to_torsion_angles(
mirror_torsion_angles = torch.cat(
[
aatype.new_ones(*aatype.shape, 3),
all_atom_mask.new_ones(*aatype.shape, 3),
1. - 2. * chi_is_ambiguous
], dim=-1
)
......
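
One plausible motivation (my reading, the commit does not state it) for turning the in-place `/=` above into an out-of-place division is autograd safety: modifying a tensor that the backward pass has saved fails at `backward()` time. A tiny reproduction using `exp`, whose backward uses its saved output:

```python
import torch

# exp saves its output for backward; dividing that output in place bumps its
# version counter, and backward() then raises "one of the variables needed
# for gradient computation has been modified by an inplace operation".
x = torch.randn(3, requires_grad=True)
y = torch.exp(x)
y /= 2.0
try:
    y.sum().backward()
except RuntimeError as err:
    print("RuntimeError:", str(err)[:60])
```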
......@@ -80,7 +80,7 @@ def compute_fape(
positions_mask: torch.Tensor,
length_scale: float,
l1_clamp_distance: Optional[float] = None,
eps=1e-4
eps=1e-8
) -> torch.Tensor:
# [*, N_frames, N_pts, 3]
local_pred_pos = pred_frames.invert()[..., None].apply(
......@@ -100,12 +100,19 @@ def compute_fape(
normed_error *= frames_mask[..., None]
normed_error *= positions_mask[..., None, :]
norm_factor = (
torch.sum(frames_mask, dim=-1) *
torch.sum(positions_mask, dim=-1)
)
normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
# FP16-friendly averaging. Roughly equivalent to:
#
# norm_factor = (
# torch.sum(frames_mask, dim=-1) *
# torch.sum(positions_mask, dim=-1)
# )
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
#
# ("roughly" because eps is necessarily duplicated in the latter
normed_error = torch.sum(normed_error, dim=-1)
normed_error = normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
normed_error = torch.sum(normed_error, dim=-1)
normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
return normed_error
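
A quick numerical check (toy shapes, random data, not OpenFold tensors) that the two-stage normalization above matches the single-division version it replaces when `eps` is negligible; dividing by the frame count and the point count separately keeps every intermediate small enough for float16:

```python
import torch

eps = 1e-8
err = torch.rand(2, 4, 6)            # [*, N_frames, N_pts]
frames_mask = torch.ones(2, 4)
positions_mask = torch.ones(2, 6)

# Original one-step average: divide once by the product of mask sums.
one_step = torch.sum(err, dim=(-1, -2)) / (
    eps + torch.sum(frames_mask, dim=-1) * torch.sum(positions_mask, dim=-1)
)

# FP16-friendly two-step average: divide after each reduction.
two_step = torch.sum(err, dim=-1)
two_step = two_step / (eps + torch.sum(frames_mask, dim=-1))[..., None]
two_step = torch.sum(two_step, dim=-1)
two_step = two_step / (eps + torch.sum(positions_mask, dim=-1))

print(torch.allclose(one_step, two_step, atol=1e-6))  # True
```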
......@@ -118,6 +125,7 @@ def backbone_loss(
use_clamped_fape: Optional[torch.Tensor] = None,
clamp_distance: float = 10.,
loss_unit_distance: float = 10.,
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
pred_aff = T.from_tensor(traj)
......@@ -132,6 +140,7 @@ def backbone_loss(
backbone_affine_mask[..., None, :],
l1_clamp_distance=clamp_distance,
length_scale=loss_unit_distance,
eps=eps,
)
if(use_clamped_fape is not None):
......@@ -144,6 +153,7 @@ def backbone_loss(
backbone_affine_mask[..., None, :],
l1_clamp_distance=None,
length_scale=loss_unit_distance,
eps=eps,
)
fape_loss = (
......@@ -167,6 +177,7 @@ def sidechain_loss(
alt_naming_is_better: torch.Tensor,
clamp_distance: float = 10.,
length_scale: float = 10.,
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
renamed_gt_frames = (
......@@ -200,7 +211,7 @@ def sidechain_loss(
renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(
*batch_dims, -1
)
fape = compute_fape(
sidechain_frames,
renamed_gt_frames,
......@@ -210,6 +221,7 @@ def sidechain_loss(
renamed_atom14_gt_exists,
l1_clamp_distance=clamp_distance,
length_scale=length_scale,
eps=eps,
)
return fape
......@@ -428,12 +440,16 @@ def distogram_loss(
square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
mean = (
torch.sum(errors * square_mask, dim=(-1, -2)) /
(eps + torch.sum(square_mask, dim=(-1, -2)))
)
# FP16-friendly sum. Equivalent to:
# mean = (torch.sum(errors * square_mask, dim=(-1, -2)) /
# (eps + torch.sum(square_mask, dim=(-1, -2))))
denom = eps + torch.sum(square_mask, dim=(-1, -2))
mean = errors * square_mask
mean = torch.sum(mean, dim=-1)
mean = mean / denom[..., None]
mean = torch.sum(mean, dim=-1)
return mean
return mean
def tm_score(
......@@ -1285,10 +1301,18 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
logits,
torch.nn.functional.one_hot(true_msa, num_classes=23)
)
loss = (
torch.sum(errors * bert_mask, dim=(-1, -2)) /
(eps + torch.sum(bert_mask, dim=(-1, -2)))
)
# FP16-friendly averaging. Equivalent to:
# loss = (
# torch.sum(errors * bert_mask, dim=(-1, -2)) /
# (eps + torch.sum(bert_mask, dim=(-1, -2)))
# )
denom = eps + torch.sum(bert_mask, dim=(-1, -2))
loss = errors * bert_mask
loss = torch.sum(loss, dim=-1)
loss = loss / denom[..., None]
loss = torch.sum(loss, dim=-1)
return loss
......@@ -1298,9 +1322,7 @@ class AlphaFoldLoss(nn.Module):
super(AlphaFoldLoss, self).__init__()
self.config = config
def forward(self, out, batch):
cum_loss = 0
def forward(self, out, batch):
if("violation" not in out.keys() and self.config.violation.weight):
out["violation"] = find_structural_violations(
batch,
......@@ -1331,6 +1353,7 @@ class AlphaFoldLoss(nn.Module):
if("chi_angles_sin_cos" not in batch.keys()):
batch.update(feats.atom37_to_torsion_angles(
**batch,
eps=self.config.eps,
))
# TODO: Verify that this is correct
......@@ -1382,12 +1405,14 @@ class AlphaFoldLoss(nn.Module):
),
}
cum_loss = 0
for k,loss_fn in loss_fns.items():
weight = self.config[k].weight
if(weight):
print(k)
loss = loss_fn()
#print(k)
#print(loss)
cum_loss += weight * loss
print(weight * loss)
cum_loss = cum_loss + weight * loss
print(cum_loss)
return cum_loss
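
Finally, the weighting loop above only evaluates a loss term when its config weight is non-zero, which is why so many weights in this commit are temporarily set to `0.` while FAPE is being debugged. A toy sketch of the pattern (placeholder loss functions and weights):

```python
import torch

loss_fns = {
    "fape": lambda: torch.tensor(1.3),
    "distogram": lambda: torch.tensor(0.7),
}
weights = {"fape": 1.0, "distogram": 0.0}

cum_loss = 0
for name, loss_fn in loss_fns.items():
    weight = weights[name]
    if weight:  # zero-weight terms are skipped entirely
        cum_loss = cum_loss + weight * loss_fn()
print(cum_loss)  # tensor(1.3000)
```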