Unverified Commit 3d87ef2f authored by Dingquan Yu's avatar Dingquan Yu Committed by GitHub
Browse files

Merge pull request #3 from dingquanyu/modify-assignment-stage

Modify assignment stage
parents 2a70e080 eeb035c2
...@@ -1677,7 +1677,7 @@ def chain_center_of_mass_loss( ...@@ -1677,7 +1677,7 @@ def chain_center_of_mass_loss(
# # # #
def kabsch_rotation(P, Q): def kabsch_rotation(P, Q):
""" """
Use scipy.spatial package to calculate best rotation that minimises Use procrustes package to calculate best rotation that minimises
the RMSD between P and Q the RMSD between P and Q
The optimal rotation matrix was calculated using The optimal rotation matrix was calculated using
...@@ -1755,19 +1755,6 @@ def compute_rmsd( ...@@ -1755,19 +1755,6 @@ def compute_rmsd(
msd = torch.nan_to_num(msd, nan=1e8) msd = torch.nan_to_num(msd, nan=1e8)
return torch.sqrt(msd + eps) return torch.sqrt(msd + eps)
def kabsch_rmsd(
    true_atom_pos: torch.Tensor,
    pred_atom_pos: torch.Tensor,
    atom_mask: torch.Tensor,
):
    """Compute the RMSD between true and predicted positions after optimal alignment.

    First solves for the rigid transform (rotation + translation) that best
    superimposes the masked true atoms onto the predicted atoms, applies it to
    the true positions, then measures the masked RMSD of the aligned pair.

    Args:
        true_atom_pos: ground-truth atom coordinates.
        pred_atom_pos: predicted atom coordinates (same shape as true_atom_pos).
        atom_mask: mask selecting which atoms participate in alignment and RMSD.

    Returns:
        The RMSD between the aligned true positions and the predictions.
    """
    rotation, translation = get_optimal_transform(
        true_atom_pos,
        pred_atom_pos,
        atom_mask,
    )
    # Superimpose ground truth onto the prediction before scoring.
    aligned = true_atom_pos @ rotation + translation
    return compute_rmsd(aligned, pred_atom_pos, atom_mask)
def get_least_asym_entity_or_longest_length(batch): def get_least_asym_entity_or_longest_length(batch):
""" """
...@@ -1802,6 +1789,11 @@ def get_least_asym_entity_or_longest_length(batch): ...@@ -1802,6 +1789,11 @@ def get_least_asym_entity_or_longest_length(batch):
least_asym_entities = random.choice(least_asym_entities) least_asym_entities = random.choice(least_asym_entities)
assert len(least_asym_entities)==1 assert len(least_asym_entities)==1
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]]) best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])
# If there is more than one chain in the predicted output that has the same sequence
# as the chosen ground truth anchor, then randomly pick one
if len(best_pred_asym) > 1:
best_pred_asym = random.choice(best_pred_asym)
return least_asym_entities[0], best_pred_asym return least_asym_entities[0], best_pred_asym
...@@ -2032,65 +2024,49 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2032,65 +2024,49 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask] per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym is chosen to be: {anchor_gt_asym}") print(f"anchor_gt_asym is : {anchor_gt_asym} and anchor_pred_asym is {anchor_pred_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1 anchor_gt_idx = int(anchor_gt_asym) - 1
best_rmsd = 1e20
best_labels = None
unique_entity_ids = torch.unique(batch["entity_id"]) unique_entity_ids = torch.unique(batch["entity_id"])
entity_2_asym_list = {} entity_2_asym_list = {}
for cur_ent_id in unique_entity_ids: for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask]) cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
for cur_asym_id in anchor_pred_asym: asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
asym_mask = (batch["asym_id"] == cur_asym_id).bool() anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)]
anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx]
anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx] anchor_pred_pos = pred_ca_pos[asym_mask]
anchor_pred_pos = pred_ca_pos[asym_mask] anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx]
anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx] anchor_pred_mask = pred_ca_mask[asym_mask]
anchor_pred_mask = pred_ca_mask[asym_mask] r, x = get_optimal_transform(
r, x = get_optimal_transform( anchor_true_pos,
anchor_true_pos, anchor_pred_pos,
anchor_pred_pos, (anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(),
(anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(), )
aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms
align = greedy_align(
batch,
per_asym_residue_index,
unique_asym_ids ,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
aligned_true_ca_poses,
true_ca_masks,
) )
merged_labels = merge_labels(
aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms batch,
for _ in range(shuffle_times): per_asym_residue_index,
shuffle_idx = torch.randperm( labels,
unique_asym_ids.shape[0], device=unique_asym_ids.device align,
) )
shuffled_asym_ids = unique_asym_ids[shuffle_idx]
align = greedy_align( print(f"finished multi-chain permutation and final align is {align}")
batch,
per_asym_residue_index, return merged_labels
shuffled_asym_ids,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
aligned_true_ca_poses,
true_ca_masks,
)
merged_labels = merge_labels(
batch,
per_asym_residue_index,
labels,
align,
)
rmsd = kabsch_rmsd(
merged_labels["all_atom_positions"][..., ca_idx, :].to('cpu') @ r.to('cpu') + x.to('cpu'),
pred_ca_pos,
(pred_ca_mask.to('cpu') * merged_labels["all_atom_mask"][..., ca_idx].to('cpu')).bool(),
)
if rmsd < best_rmsd:
best_rmsd = rmsd
best_labels = merged_labels
print(f"finished shuffling and final align is {align}")
return best_labels
def forward(self,out,batch,_return_breakdown=False): def forward(self,out,batch,_return_breakdown=False):
""" """
...@@ -2107,6 +2083,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2107,6 +2083,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
# then permutate ground truth chains before calculating the loss # then permutate ground truth chains before calculating the loss
permutated_labels = self.multi_chain_perm_align(out,features,labels) permutated_labels = self.multi_chain_perm_align(out,features,labels)
logger.info("finished multi-chain permutation") logger.info("finished multi-chain permutation")
# features.update(permutated_labels)
# self.loss(out,features)
return permutated_labels return permutated_labels
## TODO next need to check how the ground truth label is used ## TODO next need to check how the ground truth label is used
# in loss calculation. # in loss calculation.
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment