Commit 4e382310 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

remove some print statements

parent d7162be4
......@@ -980,7 +980,6 @@ def between_residue_clash_loss(
shape (N, 14)
"""
fp_type = atom14_pred_positions.dtype
# Create the distance matrix.
# (N, N, 14, 14)
dists = torch.sqrt(
......@@ -1234,7 +1233,7 @@ def find_structural_violations(
batch["atom14_atom_exists"]
* atomtype_radius[batch["residx_atom14_to_atom37"]]
)
torch.cuda.memory_summary()
# Compute the between residue clash loss.
between_residue_clashes = between_residue_clash_loss(
atom14_pred_positions=atom14_pred_positions,
......@@ -1751,7 +1750,7 @@ def get_optimal_transform(
del src_atoms,tgt_atoms,
gc.collect()
tgt_center,src_center = tgt_center.to('cuda:0'),src_center.to('cuda:0')
tgt_center,src_center = tgt_center.to('cuda'),src_center.to('cuda')
x = tgt_center.to('cpu') - src_center.to('cpu') @ r.to('cpu')
del tgt_center,src_center,mask
......@@ -2055,11 +2054,11 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx],1,anchor_residue_idx)
anchor_pred_pos = pred_ca_pos[asym_mask]
anchor_pred_pos = pred_ca_pos[0][asym_mask[0]]
# anchor_pred_pos = anchor_pred_pos.to('cuda')
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_residue_idx)
anchor_pred_mask =pred_ca_mask[asym_mask]
anchor_pred_mask =pred_ca_mask[0][asym_mask[0]]
# anchor_pred_mask = anchor_pred_mask.to('cuda')
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
r, x = get_optimal_transform(
......@@ -2085,6 +2084,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del aligned_true_ca_poses
del r,x
del pred_ca_pos,pred_ca_mask,true_ca_poses,true_ca_masks
del anchor_pred_pos,anchor_true_pos
gc.collect()
print(f"finished multi-chain permutation and final align is {align}")
merged_labels = merge_labels(
......
......@@ -330,7 +330,6 @@ def main(args):
low_prec=(str(args.precision) == "16")
)
if "multimer" in args.config_preset:
print("training multimer models now")
model_module = OpenFoldMultimerWrapper(config)
else:
model_module = OpenFoldWrapper(config)
......@@ -360,7 +359,6 @@ def main(args):
#data_module = DummyDataLoader("new_batch.pickle")
if "multimer" in args.config_preset:
print("use multimer datamodule now")
data_module = OpenFoldMultimerDataModule(
config=config.data,
batch_seed=args.seed,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment