Commit 8f1ebd58 authored by Ali Cowen-Rivers's avatar Ali Cowen-Rivers Committed by Copybara-Service
Browse files

Fix GPU relax for longer chains by pinning large memory ops to cpu.

PiperOrigin-RevId: 501105389
Change-Id: I6c981d1d3231e008ebae192edb4586479eb5eb34
parent 420fb08f
...@@ -26,6 +26,7 @@ from alphafold.relax import cleanup ...@@ -26,6 +26,7 @@ from alphafold.relax import cleanup
from alphafold.relax import utils from alphafold.relax import utils
import ml_collections import ml_collections
import numpy as np import numpy as np
import jax
from simtk import openmm from simtk import openmm
from simtk import unit from simtk import unit
from simtk.openmm import app as openmm_app from simtk.openmm import app as openmm_app
...@@ -486,7 +487,9 @@ def run_pipeline( ...@@ -486,7 +487,9 @@ def run_pipeline(
pdb_string = clean_protein(prot, checks=True) pdb_string = clean_protein(prot, checks=True)
else: else:
pdb_string = ret["min_pdb"] pdb_string = ret["min_pdb"]
ret.update(get_violation_metrics(prot)) # Calculation of violations can cause CUDA errors for some JAX versions.
with jax.default_device(jax.devices("cpu")[0]):
ret.update(get_violation_metrics(prot))
ret.update({ ret.update({
"num_exclusions": len(exclude_residues), "num_exclusions": len(exclude_residues),
"iteration": iteration, "iteration": iteration,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment