Commit 2c4d4183 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

remove unecessary print statements

parent e4d7f6d2
......@@ -1610,7 +1610,6 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs
Returns:
Masked MSA loss
"""
print(f"logits shape: {logits.shape} true_msa shape:{true_msa.shape}")
errors = softmax_cross_entropy(
logits, torch.nn.functional.one_hot(true_msa, num_classes=num_classes)
)
......@@ -1881,13 +1880,12 @@ def greedy_align(
return align
def merge_labels(batch, per_asym_residue_index, labels, align):
def merge_labels(per_asym_residue_index, labels, align):
"""
batch:
labels: list of label dicts, each with shape [nk, *]
align: list of int, such as [2, None, 0, 1], each entry specify the corresponding label of the asym.
per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex.
labels: list of original ground truth feats
align: list of tuples, each entry specify the corresponding label of the asym.
"""
num_res = batch["msa_mask"].shape[-1]
outs = {}
for k, v in labels[0].items():
cur_out = {}
......@@ -1904,11 +1902,7 @@ def merge_labels(batch, per_asym_residue_index, labels, align):
cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0:
new_v = torch.concat(cur_out, dim=dimension_to_merge)
print(f"k is {k} shape:{label.shape} and dimension_to_merge:{dimension_to_merge}")
outs[k] = new_v
print(f"finished merging")
for k,v in outs.items():
print(f"{k}:{v.shape}")
return outs
......@@ -2098,7 +2092,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
gc.collect()
print(f"finished multi-chain permutation and final align is {align}")
merged_labels = merge_labels(
batch,
per_asym_residue_index,
labels,
align,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment