"examples/vscode:/vscode.git/clone" did not exist on "942a0fb9100ac006c67cbf42807917ea9863cbcf"
Commit 2c4d4183 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

remove unecessary print statements

parent e4d7f6d2
...@@ -1610,7 +1610,6 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs ...@@ -1610,7 +1610,6 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs
Returns: Returns:
Masked MSA loss Masked MSA loss
""" """
print(f"logits shape: {logits.shape} true_msa shape:{true_msa.shape}")
errors = softmax_cross_entropy( errors = softmax_cross_entropy(
logits, torch.nn.functional.one_hot(true_msa, num_classes=num_classes) logits, torch.nn.functional.one_hot(true_msa, num_classes=num_classes)
) )
...@@ -1881,13 +1880,12 @@ def greedy_align( ...@@ -1881,13 +1880,12 @@ def greedy_align(
return align return align
def merge_labels(batch, per_asym_residue_index, labels, align): def merge_labels(per_asym_residue_index, labels, align):
""" """
batch: per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex.
labels: list of label dicts, each with shape [nk, *] labels: list of original ground truth feats
align: list of int, such as [2, None, 0, 1], each entry specify the corresponding label of the asym. align: list of tuples, each entry specify the corresponding label of the asym.
""" """
num_res = batch["msa_mask"].shape[-1]
outs = {} outs = {}
for k, v in labels[0].items(): for k, v in labels[0].items():
cur_out = {} cur_out = {}
...@@ -1904,11 +1902,7 @@ def merge_labels(batch, per_asym_residue_index, labels, align): ...@@ -1904,11 +1902,7 @@ def merge_labels(batch, per_asym_residue_index, labels, align):
cur_out = [x[1] for x in sorted(cur_out.items())] cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0: if len(cur_out)>0:
new_v = torch.concat(cur_out, dim=dimension_to_merge) new_v = torch.concat(cur_out, dim=dimension_to_merge)
print(f"k is {k} shape:{label.shape} and dimension_to_merge:{dimension_to_merge}")
outs[k] = new_v outs[k] = new_v
print(f"finished merging")
for k,v in outs.items():
print(f"{k}:{v.shape}")
return outs return outs
...@@ -2098,7 +2092,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2098,7 +2092,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
gc.collect() gc.collect()
print(f"finished multi-chain permutation and final align is {align}") print(f"finished multi-chain permutation and final align is {align}")
merged_labels = merge_labels( merged_labels = merge_labels(
batch,
per_asym_residue_index, per_asym_residue_index,
labels, labels,
align, align,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment