Commit 0a0a741f authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix circular import

parent 1a341511
...@@ -24,14 +24,7 @@ from openfold.np import ( ...@@ -24,14 +24,7 @@ from openfold.np import (
protein, protein,
residue_constants, residue_constants,
) )
from openfold.utils.loss import ( import openfold.utils.loss as loss
find_structural_violations_np,
compute_violation_metrics_np,
)
find_structural_violations = find_structural_violations_np
compute_violation_metrics = compute_violation_metrics_np
from openfold.np.relax import cleanup, utils from openfold.np.relax import cleanup, utils
import ml_collections import ml_collections
import numpy as np import numpy as np
...@@ -343,14 +336,14 @@ def find_violations(prot_np: protein.Protein): ...@@ -343,14 +336,14 @@ def find_violations(prot_np: protein.Protein):
batch["seq_mask"] = np.ones_like(batch["aatype"], np.float32) batch["seq_mask"] = np.ones_like(batch["aatype"], np.float32)
batch = make_atom14_positions(batch) batch = make_atom14_positions(batch)
violations = find_structural_violations( violations = loss.find_structural_violations_np(
batch=batch, batch=batch,
atom14_pred_positions=batch["atom14_gt_positions"], atom14_pred_positions=batch["atom14_gt_positions"],
config=ml_collections.ConfigDict( config=ml_collections.ConfigDict(
{"violation_tolerance_factor": 12, # Taken from model config. {"violation_tolerance_factor": 12, # Taken from model config.
"clash_overlap_tolerance": 1.5, # Taken from model config. "clash_overlap_tolerance": 1.5, # Taken from model config.
})) }))
violation_metrics = compute_violation_metrics( violation_metrics = loss.compute_violation_metrics_np(
batch=batch, batch=batch,
atom14_pred_positions=batch["atom14_gt_positions"], atom14_pred_positions=batch["atom14_gt_positions"],
violations=violations, violations=violations,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment