"deploy/examples/llm/components/processor.py" did not exist on "1af7433bffac503dc3ecbb6834f4baf6e9358c33"
Commit faca088f authored by Geoffrey Yu

split multi_chain_perm_align into multiple smaller functions

parent be127915
@@ -1834,7 +1834,6 @@ def get_least_asym_entity_or_longest_length(batch):
 def greedy_align(
     batch,
     per_asym_residue_index,
-    unique_asym_ids,
     entity_2_asym_list,
     pred_ca_pos,
     pred_ca_mask,
@@ -1847,6 +1846,7 @@ def greedy_align(
     """
     used = [False for _ in range(len(true_ca_poses))]
     align = []
+    unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0]
     for cur_asym_id in unique_asym_ids:
         i = int(cur_asym_id - 1)
         asym_mask = batch["asym_id"] == cur_asym_id
@@ -2064,6 +2064,53 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
         labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=dim_dict[k])] for k, value in batch.items() if k in REQUIRED_FEATURES])))
         return labels
+    @staticmethod
+    def get_per_asym_residue_index(batch):
+        unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0]
+        per_asym_residue_index = {}
+        for cur_asym_id in unique_asym_ids:
+            asym_mask = (batch["asym_id"] == cur_asym_id).bool()
+            per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
+        return per_asym_residue_index
+
+    @staticmethod
+    def get_entity_2_asym_list(batch):
+        entity_2_asym_list = {}
+        unique_entity_ids = torch.unique(batch["entity_id"])
+        for cur_ent_id in unique_entity_ids:
+            ent_mask = batch["entity_id"] == cur_ent_id
+            cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
+            entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
+        return entity_2_asym_list
+
+    @staticmethod
+    def calculate_input_mask(true_ca_masks, anchor_gt_idx,
+                             asym_mask, pred_ca_mask,
+                             anchor_residue_idx):
+        anchor_pred_mask = pred_ca_mask[0][asym_mask[0]]
+        anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx], 1, anchor_residue_idx)
+        input_mask = (anchor_true_mask * anchor_pred_mask).bool()
+        return input_mask
+
+    @staticmethod
+    def calculate_optimal_transform(true_ca_poses,
+                                    anchor_gt_idx, anchor_residue_idx,
+                                    true_ca_masks, pred_ca_mask,
+                                    asym_mask,
+                                    pred_ca_pos):
+        input_mask = AlphaFoldMultimerLoss.calculate_input_mask(true_ca_masks,
+                                                                anchor_gt_idx, asym_mask,
+                                                                pred_ca_mask, anchor_residue_idx)
+        anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx], 1, anchor_residue_idx)
+        anchor_pred_pos = pred_ca_pos[0][asym_mask[0]]
+        r, x = get_optimal_transform(
+            anchor_pred_pos, anchor_true_pos[0],
+            mask=input_mask[0]
+        )
+        return r, x
+
     @staticmethod
     def multi_chain_perm_align(out, batch, dim_dict, permutate_chains=True):
         """
@@ -2080,54 +2127,35 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
         pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
         pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
-        true_ca_poses = [
-            l["all_atom_positions"][..., ca_idx, :] for l in labels
-        ] # list([nres, 3])
         true_ca_masks = [
             l["all_atom_mask"][..., ca_idx].long() for l in labels
         ] # list([nres,])
-        unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0]
-        per_asym_residue_index = {}
-        for cur_asym_id in unique_asym_ids:
-            asym_mask = (batch["asym_id"] == cur_asym_id).bool()
-            per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
         if permutate_chains:
             anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
             print(f"anchor_gt_asym:{anchor_gt_asym} anchor_pred_asym:{anchor_pred_asym}")
             anchor_gt_idx = int(anchor_gt_asym) - 1
-            unique_entity_ids = torch.unique(batch["entity_id"])
-            entity_2_asym_list = {}
-            for cur_ent_id in unique_entity_ids:
-                ent_mask = batch["entity_id"] == cur_ent_id
-                cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
-                entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
             asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
+            per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(batch)
             anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
-            anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx], 1, anchor_residue_idx)
-            anchor_pred_pos = pred_ca_pos[0][asym_mask[0]]
-            anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx], 1, anchor_residue_idx)
-            anchor_pred_mask = pred_ca_mask[0][asym_mask[0]]
-            input_mask = (anchor_true_mask * anchor_pred_mask).bool()
-            r, x = get_optimal_transform(
-                anchor_pred_pos, anchor_true_pos[0],
-                mask=input_mask[0]
-            )
-            del input_mask # just to save memory
-            del anchor_pred_mask
-            del anchor_true_mask
-            gc.collect()
+            true_ca_poses = [l["all_atom_positions"][..., ca_idx, :] for l in labels] # list([nres, 3])
+            r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses,
+                                                                     anchor_gt_idx, anchor_residue_idx,
+                                                                     true_ca_masks, pred_ca_mask,
+                                                                     asym_mask,
+                                                                     pred_ca_pos)
             aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
             del true_ca_poses
             gc.collect()
+            entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(batch)
             align = greedy_align(
                 batch,
                 per_asym_residue_index,
-                unique_asym_ids,
                 entity_2_asym_list,
                 pred_ca_pos,
                 pred_ca_mask,
@@ -2138,7 +2166,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
             del aligned_true_ca_poses, true_ca_masks
             del r, x
             del pred_ca_pos, pred_ca_mask
-            del anchor_pred_pos, anchor_true_pos
             gc.collect()
             print(f"finished multi-chain permutation and final align is {align}")
         else:
...
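
To make the split concrete, here is a small, hedged sketch (not part of the commit) of what the two new lookup helpers compute. The toy batch below is hypothetical; the tensor names and the padding-id-0 convention come from the diff above, and the inline loops mirror get_per_asym_residue_index and get_entity_2_asym_list rather than importing the class. Judging by its arguments, greedy_align then uses these two maps to try each chain in entity_2_asym_list[entity] as a candidate match for a predicted chain, pulling out that chain's residues via per_asym_residue_index.

# Hypothetical toy batch: two chains (asym_id 1 and 2) that are copies of the
# same entity (a homodimer); asym_id 0 marks padding, as in the diff above.
import torch

batch = {
    "asym_id":       torch.tensor([[1, 1, 1, 2, 2, 2, 0]]),
    "entity_id":     torch.tensor([[1, 1, 1, 1, 1, 1, 0]]),
    "residue_index": torch.tensor([[0, 1, 2, 0, 1, 2, 0]]),
}

# Per-chain residue indices, skipping the padding id 0
# (mirrors AlphaFoldMultimerLoss.get_per_asym_residue_index).
per_asym_residue_index = {}
for cur_asym_id in [i for i in torch.unique(batch["asym_id"]) if i != 0]:
    asym_mask = (batch["asym_id"] == cur_asym_id).bool()
    per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)

# Entity id -> all asym (chain) ids that are copies of that entity
# (mirrors AlphaFoldMultimerLoss.get_entity_2_asym_list).
entity_2_asym_list = {}
for cur_ent_id in torch.unique(batch["entity_id"]):
    ent_mask = batch["entity_id"] == cur_ent_id
    entity_2_asym_list[int(cur_ent_id)] = torch.unique(batch["asym_id"][ent_mask])

print(per_asym_residue_index)  # {1: tensor([0, 1, 2]), 2: tensor([0, 1, 2])}
print(entity_2_asym_list)      # {0: tensor([0]), 1: tensor([1, 2])}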
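
The transform returned by calculate_optimal_transform is applied as ca.to(r.dtype) @ r + x, i.e. a 3x3 rotation on the right plus a translation, fitted on the masked anchor-chain CA atoms. Below is a minimal sketch of such a masked rigid fit (a standard Kabsch solve); the repository's get_optimal_transform may differ in its details, and the function name here is hypothetical.

import math
import torch

def optimal_transform_sketch(pred_pos, true_pos, mask):
    # Hedged sketch, not the repository's get_optimal_transform.
    # pred_pos, true_pos: [n, 3]; mask: [n] bool.
    # Returns (r, x) with r a proper rotation [3, 3] and x a translation [3]
    # such that true_pos @ r + x best matches pred_pos on the masked residues.
    p = pred_pos[mask]
    t = true_pos[mask]
    p_mean, t_mean = p.mean(dim=0), t.mean(dim=0)
    p_c, t_c = p - p_mean, t - t_mean
    # Covariance of the two centred point clouds.
    cov = t_c.transpose(-1, -2) @ p_c
    u, _, vt = torch.linalg.svd(cov)
    # Reflection correction so det(r) = +1.
    e = torch.eye(3, dtype=cov.dtype)
    e[2, 2] = torch.sign(torch.det(u @ vt))
    r = u @ e @ vt
    x = p_mean - t_mean @ r
    return r, x

# Usage, shaped like the call in calculate_optimal_transform: recover the
# rigid motion between two copies of the same point cloud.
c, s = math.cos(0.3), math.sin(0.3)
rot = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])  # rotation about z
pred = torch.randn(10, 3)
true = pred @ rot.T + 1.0                      # same points, rigidly moved
mask = torch.ones(10, dtype=torch.bool)
r, x = optimal_transform_sketch(pred, true, mask)
print(torch.allclose(true @ r + x, pred, atol=1e-4))  # True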