Commit d3b2b265 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update codes

parent d4b6163d
......@@ -28,6 +28,7 @@ def kabsch_rotation(P, Q):
"""
# Computation of the covariance matrix
P,Q = P.to('cpu'),Q.to('cpu') # move to cpu memory just in case it takes up too much gpu mem
C = P.transpose(-1, -2) @ Q
# Computation of the optimal rotation matrix
......@@ -66,6 +67,7 @@ def get_optimal_transform(
src_center = src_atoms.mean(-2, keepdim=True)
tgt_center = tgt_atoms.mean(-2, keepdim=True)
r = kabsch_rotation(src_atoms - src_center, tgt_atoms - tgt_center)
tgt_center,src_center = tgt_center.to('cpu'),src_center.to('cpu') # load to cpu memory just in case
x = tgt_center - src_center @ r
return r, x
......@@ -158,7 +160,7 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
(anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(),
)
aligned_true_ca_poses = [ca @ r + x for ca in true_ca_poses] # apply transforms
aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms
for _ in range(shuffle_times):
shuffle_idx = torch.randperm(
unique_asym_ids.shape[0], device=unique_asym_ids.device
......@@ -260,13 +262,17 @@ def greedy_align(
best_idx = None
cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
cur_pred_pos = pred_ca_pos[asym_mask]
cur_pred_mask = pred_ca_mask[asym_mask]
cur_pred_pos = pred_ca_pos[asym_mask[:,0],:] # only need the first 1 column of asym_mask
print(f"line 266 cur_pred_pos shape: {cur_pred_pos.shape} and pred_ca_pos is {pred_ca_pos.shape}")
cur_pred_mask = pred_ca_mask[asym_mask[:,0]]# only need the first column of asym_mask
for next_asym_id in cur_asym_list:
if next_asym_id == 0:
continue
j = int(next_asym_id - 1)
if not used[j]: # posesible candidate
if not used[j]: # possible candidate
print(f"line 265 curr_residue_index is {cur_residue_index} and j is {j}")
print(f"true_ca_poses shape: {true_ca_poses[j].shape}")
cropped_pos = true_ca_poses[j][cur_residue_index]
mask = true_ca_masks[j][cur_residue_index]
rmsd = compute_rmsd(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment