# fape.py
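"""Frame-aligned point error (FAPE) losses.

Implements the backbone and side-chain FAPE terms of the AlphaFold-style
structure loss, with separate clamping hyperparameters for intra-chain and
inter-chain (interface) residue pairs in the multimer setting.
"""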
import ml_collections
import torch
from typing import Dict, Optional, Tuple, Union

from .geometry import compute_fape
from unifold.modules.frame import Frame


def backbone_loss(
    true_frame_tensor: torch.Tensor,
    frame_mask: torch.Tensor,
    traj: torch.Tensor,
    use_clamped_fape: torch.Tensor,
    clamp_distance: float = 10.0,
    loss_unit_distance: float = 10.0,
    clamp_distance_between_chains: float = 30.0,
    loss_unit_distance_between_chains: float = 20.0,
    intra_chain_mask: Optional[torch.Tensor] = None,
    eps: float = 1e-4,
    **kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
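    """Backbone FAPE computed over the structure-module trajectory `traj`.

    If `intra_chain_mask` is None, returns a single per-layer FAPE tensor.
    Otherwise (the multimer case) returns a tuple
    `(intra_chain_fape, interface_fape)`, where the interface term uses its
    own clamp distance and length scale.
    """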
    pred_aff = Frame.from_tensor_4x4(traj)
    gt_aff = Frame.from_tensor_4x4(true_frame_tensor)

    # `use_clamped_fape` arrives as a 0/1 tensor; convert it to a Python bool.
    use_clamped_fape = int(use_clamped_fape) == 1
    if intra_chain_mask is None:
        return compute_fape(
            pred_aff,
            gt_aff[None],
            frame_mask[None],
            pred_aff.get_trans(),
            gt_aff[None].get_trans(),
            frame_mask[None],
            pair_mask=None,
            l1_clamp_distance=clamp_distance if use_clamped_fape else None,
            length_scale=loss_unit_distance,
            eps=eps,
        )
    else:
        # Broadcast the pair mask over the trajectory (layer) dimension.
        intra_chain_mask = intra_chain_mask.float().unsqueeze(0)
        intra_chain_bb_loss = compute_fape(
            pred_aff,
            gt_aff[None],
            frame_mask[None],
            pred_aff.get_trans(),
            gt_aff[None].get_trans(),
            frame_mask[None],
            pair_mask=intra_chain_mask,
            l1_clamp_distance=clamp_distance if use_clamped_fape else None,
            length_scale=loss_unit_distance,
            eps=eps,
        )
        # Interface pairs use their own clamp distance and length scale.
        interface_fape = compute_fape(
            pred_aff,
            gt_aff[None],
            frame_mask[None],
            pred_aff.get_trans(),
            gt_aff[None].get_trans(),
            frame_mask[None],
            pair_mask=1.0 - intra_chain_mask,
            l1_clamp_distance=clamp_distance_between_chains
            if use_clamped_fape
            else None,
            length_scale=loss_unit_distance_between_chains,
            eps=eps,
        )
        return intra_chain_bb_loss, interface_fape


def sidechain_loss(
    sidechain_frames: torch.Tensor,
    sidechain_atom_pos: torch.Tensor,
    rigidgroups_gt_frames: torch.Tensor,
    rigidgroups_alt_gt_frames: torch.Tensor,
    rigidgroups_gt_exists: torch.Tensor,
    renamed_atom14_gt_positions: torch.Tensor,
    renamed_atom14_gt_exists: torch.Tensor,
    alt_naming_is_better: torch.Tensor,
    clamp_distance: float = 10.0,
    length_scale: float = 10.0,
    eps: float = 1e-4,
    **kwargs,
) -> torch.Tensor:
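    """Side-chain FAPE over all rigid-group frames and atom14 positions.

    Ground truth is taken after symmetry renaming: `alt_naming_is_better`
    selects, per residue, whichever of `rigidgroups_gt_frames` and
    `rigidgroups_alt_gt_frames` better matches the prediction.
    """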
    # Per residue, select whichever ground-truth naming (original or
    # symmetry-swapped alternative) better matches the prediction.
    renamed_gt_frames = (
        1.0 - alt_naming_is_better[..., None, None, None]
    ) * rigidgroups_gt_frames + alt_naming_is_better[
        ..., None, None, None
    ] * rigidgroups_alt_gt_frames

    # Flatten the per-residue rigid-group and atom dimensions into a single
    # points axis, as expected by compute_fape.
    batch_dims = sidechain_frames.shape[:-4]
    sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
    sidechain_frames = Frame.from_tensor_4x4(sidechain_frames)

    renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
    renamed_gt_frames = Frame.from_tensor_4x4(renamed_gt_frames)
    rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
    sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
    renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(*batch_dims, -1, 3)
    renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)

    fape = compute_fape(
        sidechain_frames,
        renamed_gt_frames,
        rigidgroups_gt_exists,
        sidechain_atom_pos,
        renamed_atom14_gt_positions,
        renamed_atom14_gt_exists,
        pair_mask=None,
        l1_clamp_distance=clamp_distance,
        length_scale=length_scale,
        eps=eps,
    )
    return fape


def fape_loss(
    out: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    config: ml_collections.ConfigDict,
    loss_dict: dict,
) -> torch.Tensor:
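    """Weighted sum of backbone and side-chain FAPE.

    Per-term values from the last structure-module layer are written to
    `loss_dict` for logging; the backbone term of the returned loss is
    averaged over all layers.
    """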
    # Compute FAPE in full precision.
    for key in out["sm"]:
        out["sm"][key] = out["sm"][key].float()
    if "asym_id" in batch:
        # Mask of residue pairs that belong to the same chain (equal asym_id).
        intra_chain_mask = (
            batch["asym_id"][..., :, None] == batch["asym_id"][..., None, :]
        )
        bb_loss, interface_loss = backbone_loss(
            traj=out["sm"]["frames"],
            **{**batch, **config.backbone},
            intra_chain_mask=intra_chain_mask,
        )
        # Log only the loss of the last structure-module layer.
        loss_dict["fape"] = bb_loss[-1].data
        loss_dict["interface_fape"] = interface_loss[-1].data
        # Average over layers for the training objective.
        bb_loss = torch.mean(bb_loss, dim=0) + torch.mean(interface_loss, dim=0)
    else:
        bb_loss = backbone_loss(
            traj=out["sm"]["frames"],
            **{**batch, **config.backbone},
            intra_chain_mask=None,
        )
        # Log only the loss of the last structure-module layer.
        loss_dict["fape"] = bb_loss[-1].data
        bb_loss = torch.mean(bb_loss, dim=0)

    sc_loss = sidechain_loss(
        out["sm"]["sidechain_frames"],
        out["sm"]["positions"],
        **{**batch, **config.sidechain},
    )
    loss_dict["sc_fape"] = sc_loss.data
    loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss

    return loss
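

# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original module): a minimal smoke test
# that feeds identity frames to `backbone_loss`, so the expected FAPE is
# near zero. The module path in the `-m` invocation below is an assumption
# about the repository layout; adjust it so the relative import of
# `.geometry` resolves, e.g. `python -m unifold.losses.fape`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    num_layers, num_res = 2, 8
    identity = torch.eye(4).expand(num_res, 4, 4)  # [N, 4, 4] rigid frames
    loss = backbone_loss(
        true_frame_tensor=identity,
        frame_mask=torch.ones(num_res),
        traj=identity.expand(num_layers, num_res, 4, 4),
        use_clamped_fape=torch.tensor(1.0),
    )
    print(loss)  # one near-zero value per trajectory layer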