Unverified Commit 3d87ef2f authored by Dingquan Yu's avatar Dingquan Yu Committed by GitHub
Browse files

Merge pull request #3 from dingquanyu/modify-assignment-stage

Modify assignment stage
parents 2a70e080 eeb035c2
......@@ -1677,7 +1677,7 @@ def chain_center_of_mass_loss(
# #
def kabsch_rotation(P, Q):
"""
Use scipy.spatial package to calculate best rotation that minimises
Use procrustes package to calculate best rotation that minimises
the RMSD between P and Q
The optimal rotation matrix was calculated using
......@@ -1755,19 +1755,6 @@ def compute_rmsd(
msd = torch.nan_to_num(msd, nan=1e8)
return torch.sqrt(msd + eps)
def kabsch_rmsd(
    true_atom_pos: torch.Tensor,
    pred_atom_pos: torch.Tensor,
    atom_mask: torch.Tensor,
):
    """Superimpose the ground-truth atoms onto the prediction and return the RMSD.

    The optimal rigid-body transform (Kabsch) is estimated from the atoms
    selected by ``atom_mask``, applied to ``true_atom_pos``, and the RMSD
    between the aligned ground truth and ``pred_atom_pos`` is returned.

    Args:
        true_atom_pos: ground-truth atom coordinates.
        pred_atom_pos: predicted atom coordinates.
        atom_mask: boolean/binary mask selecting atoms used for alignment
            and RMSD computation.

    Returns:
        The RMSD (as computed by ``compute_rmsd``) after optimal alignment.
    """
    rotation, translation = get_optimal_transform(true_atom_pos, pred_atom_pos, atom_mask)
    # Apply the estimated rigid transform to the ground-truth coordinates.
    aligned_true = true_atom_pos @ rotation + translation
    return compute_rmsd(aligned_true, pred_atom_pos, atom_mask)
def get_least_asym_entity_or_longest_length(batch):
"""
......@@ -1802,6 +1789,11 @@ def get_least_asym_entity_or_longest_length(batch):
least_asym_entities = random.choice(least_asym_entities)
assert len(least_asym_entities)==1
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])
# If there is more than one chain in the predicted output that has the same sequence
# as the chosen ground truth anchor, then randomly pick one
if len(best_pred_asym) > 1:
best_pred_asym = random.choice(best_pred_asym)
return least_asym_entities[0], best_pred_asym
......@@ -2032,21 +2024,17 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym is chosen to be: {anchor_gt_asym}")
print(f"anchor_gt_asym is : {anchor_gt_asym} and anchor_pred_asym is {anchor_pred_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1
best_rmsd = 1e20
best_labels = None
unique_entity_ids = torch.unique(batch["entity_id"])
entity_2_asym_list = {}
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
for cur_asym_id in anchor_pred_asym:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)]
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx]
anchor_pred_pos = pred_ca_pos[asym_mask]
......@@ -2059,15 +2047,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
)
aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms
for _ in range(shuffle_times):
shuffle_idx = torch.randperm(
unique_asym_ids.shape[0], device=unique_asym_ids.device
)
shuffled_asym_ids = unique_asym_ids[shuffle_idx]
align = greedy_align(
batch,
per_asym_residue_index,
shuffled_asym_ids,
unique_asym_ids ,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
......@@ -2080,17 +2063,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
labels,
align,
)
rmsd = kabsch_rmsd(
merged_labels["all_atom_positions"][..., ca_idx, :].to('cpu') @ r.to('cpu') + x.to('cpu'),
pred_ca_pos,
(pred_ca_mask.to('cpu') * merged_labels["all_atom_mask"][..., ca_idx].to('cpu')).bool(),
)
if rmsd < best_rmsd:
best_rmsd = rmsd
best_labels = merged_labels
print(f"finished shuffling and final align is {align}")
return best_labels
print(f"finished multi-chain permutation and final align is {align}")
return merged_labels
def forward(self,out,batch,_return_breakdown=False):
"""
......@@ -2107,6 +2083,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
# then permutate ground truth chains before calculating the loss
permutated_labels = self.multi_chain_perm_align(out,features,labels)
logger.info("finished multi-chain permutation")
# features.update(permutated_labels)
# self.loss(out,features)
return permutated_labels
## TODO next need to check how the ground truth label is used
# in loss calculation.
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment