Commit 1d47c1e7 authored by Gustaf Ahdritz

Finish accommodating FP16. FAPE does not decrease

parent 3d9d977a
@@ -12,17 +12,20 @@ def model_config(name):
         c.model.template.enabled = False
     return c
 
 c_z = mlc.FieldReference(128)
 c_m = mlc.FieldReference(256)
 c_t = mlc.FieldReference(64)
 c_e = mlc.FieldReference(64)
 c_s = mlc.FieldReference(384)
-blocks_per_ckpt = mlc.FieldReference(1)
-chunk_size = mlc.FieldReference(4)#1280)
+blocks_per_ckpt = mlc.FieldReference(1, field_type=int)
+chunk_size = mlc.FieldReference(None, field_type=int)
 aux_distogram_bins = mlc.FieldReference(64)
+eps = 1e-4
+inf = 1e4
 
 config = mlc.ConfigDict({
     "model": {
         "c_z": c_z,
@@ -45,7 +48,7 @@ config = mlc.ConfigDict({
             "min_bin": 3.25,
             "max_bin": 20.75,
             "no_bins": 15,
-            "inf": 1e8,
+            "inf": inf,#1e8,
         },
         "template": {
             "distogram": {
@@ -74,6 +77,7 @@ config = mlc.ConfigDict({
                 "dropout_rate": 0.25,
                 "blocks_per_ckpt": blocks_per_ckpt,
                 "chunk_size": chunk_size,
+                "inf": inf,
             },
             "template_pointwise_attention": {
                 "c_t": c_t,
@@ -83,8 +87,10 @@ config = mlc.ConfigDict({
                 "c_hidden": 16,
                 "no_heads": 4,
                 "chunk_size": chunk_size,
+                "inf": inf,#1e-9,
             },
-            "eps": 1e-6,
+            "inf": inf,
+            "eps": eps,#1e-6,
             "enabled": True,
             "embed_angles": True,
         },
@@ -108,10 +114,10 @@ config = mlc.ConfigDict({
                 "pair_dropout": 0.25,
                 "blocks_per_ckpt": blocks_per_ckpt,
                 "chunk_size": chunk_size,
-                "inf": 1e9,
-                "eps": 1e-10,
+                "inf": inf,#1e9,
+                "eps": eps,#1e-10,
             },
-            "enabled": True,
+            "enabled": False,#True,
         },
         "evoformer_stack": {
             "c_m": c_m,
@@ -129,8 +135,8 @@ config = mlc.ConfigDict({
             "pair_dropout": 0.25,
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
-            "inf": 1e9,
-            "eps": 1e-10,
+            "inf": inf,#1e9,
+            "eps": eps,#1e-10,
         },
         "structure_module": {
             "c_s": c_s,
@@ -146,8 +152,8 @@ config = mlc.ConfigDict({
            "no_resnet_blocks": 2,
            "no_angles": 7,
            "trans_scale_factor": 10,
-            "epsilon": 1e-12,
-            "inf": 1e5,
+            "epsilon": eps,#1e-12,
+            "inf": inf,#1e5,
        },
        "heads": {
            "lddt": {
@@ -186,11 +192,11 @@ config = mlc.ConfigDict({
            "min_bin": 2.3125,
            "max_bin": 21.6875,
            "no_bins": 64,
-            "eps": 1e-6,
-            "weight": 0.3,
+            "eps": eps,#1e-6,
+            "weight": 0.,#0.3,
        },
        "experimentally_resolved": {
-            "eps": 1e-8,
+            "eps": eps,#1e-8,
            "min_resolution": 0.1,
            "max_resolution": 3.0,
            "weight": 0.,
@@ -206,6 +212,7 @@ config = mlc.ConfigDict({
                "length_scale": 10.,
                "weight": 0.5,
            },
+            "eps": 1e-4,
            "weight": 1.0,
        },
        "lddt": {
@@ -213,24 +220,25 @@ config = mlc.ConfigDict({
            "max_resolution": 3.0,
            "cutoff": 15.,
            "no_bins": 50,
-            "eps": 1e-10,
-            "weight": 0.01,
+            "eps": eps,#1e-10,
+            "weight": 0.,#0.01,
        },
        "masked_msa": {
-            "eps": 1e-8,
-            "weight": 2.0,
+            "eps": eps,#1e-8,
+            "weight": 0.,#2.0,
        },
        "supervised_chi": {
            "chi_weight": 0.5,
            "angle_norm_weight": 0.01,
-            "eps": 1e-6,
-            "weight": 1.0,
+            "eps": eps,#1e-6,
+            "weight": 0.,#1.0,
        },
        "violation": {
            "violation_tolerance_factor": 12.0,
            "clash_overlap_tolerance": 1.5,
-            "eps": 1e-6,
+            "eps": eps,#1e-6,
            "weight": 0.,
        },
+        "eps": eps,
    },
})
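Why 1e4 and 1e-4? Half precision only represents magnitudes between roughly 6.1e-5 (smallest normal value) and 65504, so the hard-coded constants this commit replaces (1e9, 1e-10, 1e-12, ...) overflow or flush to zero as soon as a tensor is cast to torch.float16; and because every module now reads the shared eps/inf FieldReferences, the whole model can be retuned in one place. A quick sketch (not part of the commit) that checks the ranges in PyTorch:

import torch

fp16 = torch.finfo(torch.float16)
print(fp16.max, fp16.tiny)                       # 65504.0, ~6.1e-05

print(torch.tensor(1e9, dtype=torch.float16))    # inf: the old "inf" overflows
print(torch.tensor(1e-10, dtype=torch.float16))  # 0.: the old "eps" underflows
print(torch.tensor(1e4, dtype=torch.float16))    # 10000.: the new "inf" survives
print(torch.tensor(1e-4, dtype=torch.float16))   # ~1e-4: the new "eps" survives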
@@ -115,14 +115,22 @@ class AlphaFold(nn.Module):
                 batch,
             )
 
+            #tensor_dtype = (
+            #    single_template_feats["template_all_atom_masks"].dtype
+            #)
+
             # Build template angle feats
             angle_feats = atom37_to_torsion_angles(
                 single_template_feats["template_aatype"],
-                single_template_feats["template_all_atom_positions"],
-                single_template_feats["template_all_atom_masks"],
-                eps=1e-8
+                single_template_feats["template_all_atom_positions"],#.float(),
+                single_template_feats["template_all_atom_masks"],#.float(),
+                eps=self.config.template.eps,
             )
 
+            #angle_feats = tensor_tree_map(
+            #    lambda t: t.type(tensor_dtype), angle_feats
+            #)
+
             template_angle_feat = build_template_angle_feat(
                 angle_feats,
                 single_template_feats["template_aatype"],
@@ -134,6 +142,7 @@ class AlphaFold(nn.Module):
             # [*, S_t, N, N, C_t]
             t = build_template_pair_feat(
                 single_template_feats,
+                inf=self.config.template.inf,
                 eps=self.config.template.eps,
                 **self.config.template.distogram
             )
@@ -162,11 +171,11 @@ class AlphaFold(nn.Module):
                 template_mask=batch["template_mask"]
             )
             t = t * (torch.sum(batch["template_mask"]) > 0)
 
         return {
-            "template_angle_embedding": a,
+            "template_angle_embedding": template_embeds["angle"],
             "template_pair_embedding": t,
-            "torsion_angles_mask": angle_feats["torsion_angles_mask"],
+            "torsion_angles_mask": template_embeds["torsion_mask"],
         }
 
     def iteration(self, feats, m_1_prev, z_prev, x_prev):
@@ -251,7 +260,7 @@ class AlphaFold(nn.Module):
             # [*, N, N, C_z]
             z = z + template_embeds["template_pair_embedding"]
 
             if(self.config.template.embed_angles):
                 # [*, S = S_c + S_t, N, C_m]
                 m = torch.cat(
......
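The commented-out tensor_dtype/.float() lines above sketch a mixed-precision workaround the author is experimenting with: run the numerically sensitive torsion-angle computation in fp32, then cast the results back to the ambient dtype. A minimal sketch of that pattern (not part of the commit; tensor_tree_map here is a stand-in written to mirror OpenFold's utility of the same name, which maps a function over a nested dict of tensors):

import torch

def tensor_tree_map(fn, tree):
    # Apply fn to every tensor in a (possibly nested) dict
    if isinstance(tree, dict):
        return {k: tensor_tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)

def sensitive_fn(x):
    # Stand-in for atom37_to_torsion_angles: a normalization
    # that loses precision in fp16
    return {"out": x / (x.norm(dim=-1, keepdim=True) + 1e-8)}

x16 = torch.randn(4, 3).half()
out = sensitive_fn(x16.float())                           # compute in fp32
out = tensor_tree_map(lambda t: t.type(x16.dtype), out)   # cast back to fp16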
@@ -107,7 +107,6 @@ class MSAAttention(nn.Module):
         bias = bias.expand(
             ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
         )
-
         biases = [bias]
         if(self.pair_bias):
             # [*, N_res, N_res, C_z]
......
@@ -257,6 +257,8 @@ class Attention(nn.Module):
         a += b
 
         a = self.softmax(a)
+        #print(torch.any(torch.isnan(a)))
+
         # [*, H, Q, C_hidden]
         o = torch.matmul(
             a,
@@ -328,7 +330,7 @@ class GlobalAttention(nn.Module):
             k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
         )
         bias = (self.inf * (mask - 1))[..., :, None, :]
         a += bias
         a = self.softmax(a)
         # [*, N_res, H, C_hidden]
......
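The attention changes in this file and the next are all instances of the same masking idiom: instead of boolean indexing, a large negative bias inf * (mask - 1) is added to the pre-softmax logits, which drives the masked positions' attention weights to zero. With fp16 activations the bias itself has to be representable, hence self.inf from the config rather than a hard-coded 1e9. A toy illustration (not from the commit):

import torch

inf = 1e4                                  # fp16-representable "infinity"
mask = torch.tensor([1., 1., 0.])          # last position is padding
logits = torch.tensor([0.5, 1.0, 3.0])

bias = inf * (mask - 1)                    # -> [0, 0, -10000]
print(torch.softmax(logits + bias, dim=-1))
# tensor([0.3775, 0.6225, 0.0000]): the masked entry gets ~0 weight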
@@ -50,6 +50,7 @@ class TemplatePointwiseAttention(nn.Module):
         c_hidden,
         no_heads,
         chunk_size,
+        inf,
         **kwargs
     ):
         """
@@ -68,6 +69,7 @@ class TemplatePointwiseAttention(nn.Module):
         self.c_hidden = c_hidden
         self.no_heads = no_heads
         self.chunk_size = chunk_size
+        self.inf = inf
 
         self.mha = Attention(
             self.c_z, self.c_t, self.c_t,
@@ -89,11 +91,11 @@ class TemplatePointwiseAttention(nn.Module):
         """
         if(template_mask is None):
             # NOTE: This is not the "template_mask" from the supplement, but a
-            # [*, N_templ] mask from the code. I'm pretty sure it's always just 1,
-            # but not sure enough to remove it. It's nice to have, I guess.
+            # [*, N_templ] mask from the code. I'm pretty sure it's always just
+            # 1, but not sure enough to remove it. It's nice to have, I guess.
             template_mask = t.new_ones(t.shape[:-3])
 
-        bias = (1e9 * (template_mask[..., None, None, None, None, :] - 1))
+        bias = (self.inf * (template_mask[..., None, None, None, None, :] - 1))
 
         # [*, N_res, N_res, 1, C_z]
         z = z.unsqueeze(-2)
@@ -133,6 +135,8 @@ class TemplatePairStackBlock(nn.Module):
         pair_transition_n,
         dropout_rate,
         chunk_size,
+        inf,
+        **kwargs,
     ):
         super(TemplatePairStackBlock, self).__init__()
@@ -143,6 +147,7 @@ class TemplatePairStackBlock(nn.Module):
         self.pair_transition_n = pair_transition_n
         self.dropout_rate = dropout_rate
         self.chunk_size = chunk_size
+        self.inf = inf
 
         self.dropout_row = DropoutRowwise(self.dropout_rate)
         self.dropout_col = DropoutColumnwise(self.dropout_rate)
@@ -152,12 +157,14 @@ class TemplatePairStackBlock(nn.Module):
             self.c_hidden_tri_att,
             self.no_heads,
             chunk_size=chunk_size,
+            inf=inf,
         )
         self.tri_att_end = TriangleAttentionEndingNode(
             self.c_t,
             self.c_hidden_tri_att,
             self.no_heads,
             chunk_size=chunk_size,
+            inf=inf,
         )
 
         self.tri_mul_out = TriangleMultiplicationOutgoing(
@@ -200,6 +207,7 @@ class TemplatePairStack(nn.Module):
         dropout_rate,
         blocks_per_ckpt,
         chunk_size,
+        inf=1e9,
         **kwargs,
     ):
         """
@@ -237,6 +245,7 @@ class TemplatePairStack(nn.Module):
             pair_transition_n=pair_transition_n,
             dropout_rate=dropout_rate,
             chunk_size=chunk_size,
+            inf=inf,
         )
         self.blocks.append(block)
......
@@ -57,11 +57,17 @@ class T:
             raise ValueError("Only one of rots and trans can be None")
         elif(self.rots is None):
             self.rots = T.identity_rot(
-                self.trans.shape[:-1], self.trans.dtype, self.trans.device
+                self.trans.shape[:-1],
+                self.trans.dtype,
+                self.trans.device,
+                self.trans.requires_grad,
             )
         elif(self.trans is None):
             self.trans = T.identity_trans(
-                self.rots.shape[:-2], self.rots.dtype, self.rots.device
+                self.rots.shape[:-2],
+                self.rots.dtype,
+                self.rots.device,
+                self.rots.requires_grad
             )
 
         if(self.rots.shape[-2:] != (3, 3) or
@@ -137,7 +143,7 @@ class T:
         return T(rots, trans)
 
     @staticmethod
-    def identity_rot(shape, dtype, device, requires_grad=False):
+    def identity_rot(shape, dtype, device, requires_grad):
         rots = torch.eye(
             3, dtype=dtype, device=device, requires_grad=requires_grad
         )
@@ -147,7 +153,7 @@ class T:
         return rots
 
     @staticmethod
-    def identity_trans(shape, dtype, device, requires_grad=False):
+    def identity_trans(shape, dtype, device, requires_grad):
         trans = torch.zeros(
             (*shape, 3),
             dtype=dtype,
@@ -182,20 +188,22 @@ class T:
     @staticmethod
     def from_3_points(p_neg_x_axis, origin, p_xy_plane, eps=1e-8):
-        v1 = origin - p_neg_x_axis
-        v2 = p_xy_plane - origin
-        e1 = v1 / torch.sqrt(torch.sum(v1 ** 2, dim=-1) + eps)[..., None]
-        u2 = v2 - e1 * (torch.einsum('...i,...i->...', v2, e1)[..., None])
-        e2 = u2 / torch.sqrt(torch.sum(u2 ** 2, dim=-1) + eps)[..., None]
-        e3 = torch.cross(e1, e2, dim=-1)
-
-        rots = torch.cat(
-            (
-                e1.unsqueeze(-1),
-                e2.unsqueeze(-1),
-                e3.unsqueeze(-1),
-            ), dim=-1,
-        )
+        e0 = origin - p_neg_x_axis
+        e1 = p_xy_plane - origin
+
+        # Angle norming is very sensitive to floating point imprecisions
+        #float_type = e0.dtype
+        #e0 = e0.float()
+        #e1 = e1.float()
+
+        e0 = e0 / torch.sqrt(torch.sum(e0 ** 2, dim=-1, keepdims=True) + eps)
+        e1 = e1 - e0 * torch.sum(e0 * e1, dim=-1, keepdims=True)
+        e1 = e1 / torch.sqrt(torch.sum(e1 ** 2, dim=-1, keepdims=True) + eps)
+        e2 = torch.cross(e0, e1)
+
+        rots = torch.stack([e0, e1, e2], dim=-1)
+        #rots = rots.type(float_type)
 
         return T(rots, origin)
......
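The rewritten from_3_points is a Gram-Schmidt construction: normalize the first direction, subtract its component from the second and normalize, then complete a right-handed frame with a cross product. A standalone restatement (not from the commit) that sanity-checks orthonormality:

import torch

def frame_from_3_points(p_neg_x_axis, origin, p_xy_plane, eps=1e-8):
    e0 = origin - p_neg_x_axis
    e1 = p_xy_plane - origin
    e0 = e0 / torch.sqrt(torch.sum(e0 ** 2, dim=-1, keepdims=True) + eps)
    e1 = e1 - e0 * torch.sum(e0 * e1, dim=-1, keepdims=True)  # remove e0 component
    e1 = e1 / torch.sqrt(torch.sum(e1 ** 2, dim=-1, keepdims=True) + eps)
    e2 = torch.cross(e0, e1, dim=-1)                          # right-handed third axis
    return torch.stack([e0, e1, e2], dim=-1)                  # columns are the basis

R = frame_from_3_points(
    torch.tensor([1., 0., 0.]),   # p_neg_x_axis
    torch.tensor([0., 0., 0.]),   # origin
    torch.tensor([0., 1., 0.]),   # p_xy_plane
)
print(R @ R.transpose(-1, -2))    # ~identity, so R is a valid rotation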
@@ -70,6 +70,9 @@ def checkpoint_blocks(
     for s in range(0, len(blocks), blocks_per_ckpt):
         e = s + blocks_per_ckpt
+        #print(len(args))
+        #for a in args:
+        #    print(a.requires_grad)
         args = checkpoint(chunker(s, e), *args)
         #args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
         args = wrap(args)
......
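The commented-out requires_grad prints probe a known pitfall of reentrant activation checkpointing: torch.utils.checkpoint decides whether to build a backward graph by looking only at its tensor inputs, so if none of them require grad, the checkpointed block is silently detached and its parameters receive no gradients. A minimal repro (not from the commit; this describes the reentrant implementation that was PyTorch's default at the time):

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)                        # requires_grad=False

y = checkpoint(layer, x)                     # warns: no input requires grad
print(y.requires_grad)                       # False -> layer gets no gradients

y = checkpoint(layer, x.requires_grad_())    # mark the input instead
print(y.requires_grad)                       # True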
@@ -286,7 +286,7 @@ def atom37_to_torsion_angles(
             torch.square(torsion_angles_sin_cos), dim=-1, keepdims=True
         ) + eps
     )
-    torsion_angles_sin_cos /= denom
+    torsion_angles_sin_cos = torsion_angles_sin_cos / denom
     torsion_angles_sin_cos *= torch.tensor(
         [1., 1., -1., 1., 1., 1., 1.], device=aatype.device,
@@ -298,7 +298,7 @@ def atom37_to_torsion_angles(
     mirror_torsion_angles = torch.cat(
         [
-            aatype.new_ones(*aatype.shape, 3),
+            all_atom_mask.new_ones(*aatype.shape, 3),
             1. - 2. * chi_is_ambiguous
         ], dim=-1
    )
......
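The change from /= to an out-of-place division is likely about autograd rather than precision: in-place ops bump a tensor's version counter, and if that tensor was saved for the backward pass, backprop fails. A minimal repro of the failure mode (not from the commit):

import torch

x = torch.randn(3, requires_grad=True)
y = torch.exp(x)        # exp saves its output for the backward pass
y /= 2                  # in-place edit of a saved tensor
try:
    y.sum().backward()
except RuntimeError as e:
    print(e)            # "...modified by an inplace operation..."

(In-place ops are fine when the modified tensor is not needed by any backward computation, which is presumably why the *= on the following line was left alone.)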
@@ -80,7 +80,7 @@ def compute_fape(
     positions_mask: torch.Tensor,
     length_scale: float,
     l1_clamp_distance: Optional[float] = None,
-    eps=1e-4
+    eps=1e-8
 ) -> torch.Tensor:
     # [*, N_frames, N_pts, 3]
     local_pred_pos = pred_frames.invert()[..., None].apply(
@@ -100,12 +100,19 @@ def compute_fape(
     normed_error *= frames_mask[..., None]
     normed_error *= positions_mask[..., None, :]
 
-    norm_factor = (
-        torch.sum(frames_mask, dim=-1) *
-        torch.sum(positions_mask, dim=-1)
-    )
-
-    normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
+    # FP16-friendly averaging. Roughly equivalent to:
+    #
+    # norm_factor = (
+    #     torch.sum(frames_mask, dim=-1) *
+    #     torch.sum(positions_mask, dim=-1)
+    # )
+    # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
+    #
+    # ("roughly" because eps is necessarily applied twice in the new version)
+    normed_error = torch.sum(normed_error, dim=-1)
+    normed_error = normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
+    normed_error = torch.sum(normed_error, dim=-1)
+    normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
 
     return normed_error
@@ -118,6 +125,7 @@ def backbone_loss(
     use_clamped_fape: Optional[torch.Tensor] = None,
     clamp_distance: float = 10.,
     loss_unit_distance: float = 10.,
+    eps: float = 1e-4,
     **kwargs,
 ) -> torch.Tensor:
     pred_aff = T.from_tensor(traj)
@@ -132,6 +140,7 @@ def backbone_loss(
         backbone_affine_mask[..., None, :],
         l1_clamp_distance=clamp_distance,
         length_scale=loss_unit_distance,
+        eps=eps,
     )
 
     if(use_clamped_fape is not None):
@@ -144,6 +153,7 @@ def backbone_loss(
             backbone_affine_mask[..., None, :],
             l1_clamp_distance=None,
             length_scale=loss_unit_distance,
+            eps=eps,
         )
 
         fape_loss = (
@@ -167,6 +177,7 @@ def sidechain_loss(
     alt_naming_is_better: torch.Tensor,
     clamp_distance: float = 10.,
     length_scale: float = 10.,
+    eps: float = 1e-4,
     **kwargs,
 ) -> torch.Tensor:
     renamed_gt_frames = (
@@ -200,7 +211,7 @@ def sidechain_loss(
     renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(
         *batch_dims, -1
     )
 
     fape = compute_fape(
         sidechain_frames,
         renamed_gt_frames,
@@ -210,6 +221,7 @@ def sidechain_loss(
         renamed_atom14_gt_exists,
         l1_clamp_distance=clamp_distance,
         length_scale=length_scale,
+        eps=eps,
     )
 
     return fape
@@ -428,12 +440,16 @@ def distogram_loss(
     square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
 
-    mean = (
-        torch.sum(errors * square_mask, dim=(-1, -2)) /
-        (eps + torch.sum(square_mask, dim=(-1, -2)))
-    )
+    # FP16-friendly sum. Equivalent to:
+    # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) /
+    #         (eps + torch.sum(square_mask, dim=(-1, -2))))
+    denom = eps + torch.sum(square_mask, dim=(-1, -2))
+    mean = errors * square_mask
+    mean = torch.sum(mean, dim=-1)
+    mean = mean / denom[..., None]
+    mean = torch.sum(mean, dim=-1)
 
     return mean
 
 def tm_score(
@@ -1285,10 +1301,18 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
         logits,
         torch.nn.functional.one_hot(true_msa, num_classes=23)
     )
-    loss = (
-        torch.sum(errors * bert_mask, dim=(-1, -2)) /
-        (eps + torch.sum(bert_mask, dim=(-1, -2)))
-    )
+
+    # FP16-friendly averaging. Equivalent to:
+    # loss = (
+    #     torch.sum(errors * bert_mask, dim=(-1, -2)) /
+    #     (eps + torch.sum(bert_mask, dim=(-1, -2)))
+    # )
+    denom = eps + torch.sum(bert_mask, dim=(-1, -2))
+    loss = errors * bert_mask
+    loss = torch.sum(loss, dim=-1)
+    loss = loss / denom[..., None]
+    loss = torch.sum(loss, dim=-1)
 
     return loss
@@ -1298,9 +1322,7 @@ class AlphaFoldLoss(nn.Module):
         super(AlphaFoldLoss, self).__init__()
         self.config = config
 
     def forward(self, out, batch):
-        cum_loss = 0
-
         if("violation" not in out.keys() and self.config.violation.weight):
             out["violation"] = find_structural_violations(
                 batch,
@@ -1331,6 +1353,7 @@ class AlphaFoldLoss(nn.Module):
         if("chi_angles_sin_cos" not in batch.keys()):
             batch.update(feats.atom37_to_torsion_angles(
                 **batch,
+                eps=self.config.eps,
             ))
 
         # TODO: Verify that this is correct
@@ -1382,12 +1405,14 @@ class AlphaFoldLoss(nn.Module):
             ),
         }
 
+        cum_loss = 0
         for k,loss_fn in loss_fns.items():
             weight = self.config[k].weight
             if(weight):
+                print(k)
                 loss = loss_fn()
-                #print(k)
-                #print(loss)
-                cum_loss += weight * loss
+                print(weight * loss)
+                cum_loss = cum_loss + weight * loss
 
+        print(cum_loss)
         return cum_loss
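The FP16-friendly averaging repeated in compute_fape, distogram_loss, and masked_msa_loss exists because a masked mean computed in one shot materializes a sum over all frame/position pairs, and with fp16 activations that intermediate easily exceeds the type's 65504 maximum. Dividing by one normalization factor between the two reductions keeps every intermediate small. A toy demonstration with made-up shapes (not from the commit):

import torch

errors = torch.ones(128, 1024, dtype=torch.float16)   # [N_frames, N_pts]
frames_mask = torch.ones(128, dtype=torch.float16)
positions_mask = torch.ones(1024, dtype=torch.float16)
eps = 1e-4

# One-shot reduction: 128 * 1024 = 131072 > 65504, so fp16 overflows.
print(torch.sum(errors, dim=(-1, -2)))                # inf

# Two-stage reduction: no intermediate exceeds ~1024.
x = torch.sum(errors, dim=-1)                         # per-frame sums, <= 1024
x = x / (eps + torch.sum(frames_mask, dim=-1))        # shrink before the next sum
x = torch.sum(x, dim=-1)
x = x / (eps + torch.sum(positions_mask, dim=-1))
print(x)                                              # ~1.0, finite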