Unverified Commit f6e6b6ed authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Optimizations to reporters (#4330)

* Optimizations to reporters

* Removed unneeded imports
parent 127a3733
...@@ -121,9 +121,10 @@ class DCDFile(object): ...@@ -121,9 +121,10 @@ class DCDFile(object):
raise ValueError('The number of positions must match the number of atoms') raise ValueError('The number of positions must match the number of atoms')
if is_quantity(positions): if is_quantity(positions):
positions = positions.value_in_unit(nanometers) positions = positions.value_in_unit(nanometers)
if any(math.isnan(norm(pos)) for pos in positions): import numpy as np
if np.isnan(positions).any():
raise ValueError('Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan') raise ValueError('Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
if any(math.isinf(norm(pos)) for pos in positions): if np.isinf(positions).any():
raise ValueError('Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan') raise ValueError('Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
file = self._file file = self._file
......
...@@ -104,7 +104,7 @@ class DCDReporter(object): ...@@ -104,7 +104,7 @@ class DCDReporter(object):
self._out, simulation.topology, simulation.integrator.getStepSize(), self._out, simulation.topology, simulation.integrator.getStepSize(),
simulation.currentStep, self._reportInterval, self._append simulation.currentStep, self._reportInterval, self._append
) )
self._dcd.writeModel(state.getPositions(), periodicBoxVectors=state.getPeriodicBoxVectors()) self._dcd.writeModel(state.getPositions(asNumpy=True), periodicBoxVectors=state.getPeriodicBoxVectors())
def __del__(self): def __del__(self):
self._out.close() self._out.close()
...@@ -343,9 +343,10 @@ class PDBFile(object): ...@@ -343,9 +343,10 @@ class PDBFile(object):
raise ValueError('The number of positions must match the number of atoms') raise ValueError('The number of positions must match the number of atoms')
if is_quantity(positions): if is_quantity(positions):
positions = positions.value_in_unit(angstroms) positions = positions.value_in_unit(angstroms)
if any(math.isnan(norm(pos)) for pos in positions): import numpy as np
if np.isnan(positions).any():
raise ValueError('Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan') raise ValueError('Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
if any(math.isinf(norm(pos)) for pos in positions): if np.isinf(positions).any():
raise ValueError('Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan') raise ValueError('Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
nonHeterogens = PDBFile._standardResidues[:] nonHeterogens = PDBFile._standardResidues[:]
nonHeterogens.remove('HOH') nonHeterogens.remove('HOH')
......
...@@ -104,12 +104,12 @@ class PDBReporter(object): ...@@ -104,12 +104,12 @@ class PDBReporter(object):
topology = self._subsetTopology topology = self._subsetTopology
#PDBFile will convert to angstroms so do it here first instead #PDBFile will convert to angstroms so do it here first instead
positions = state.getPositions().value_in_unit(angstroms) positions = state.getPositions(asNumpy=True).value_in_unit(angstroms)
positions = [positions[i] for i in self._atomSubset] positions = [positions[i] for i in self._atomSubset]
else: else:
topology = simulation.topology topology = simulation.topology
positions = state.getPositions() positions = state.getPositions(asNumpy=True)
if self._nextModel == 0: if self._nextModel == 0:
PDBFile.writeHeader(topology, self._out) PDBFile.writeHeader(topology, self._out)
...@@ -202,12 +202,12 @@ class PDBxReporter(PDBReporter): ...@@ -202,12 +202,12 @@ class PDBxReporter(PDBReporter):
topology = self._subsetTopology topology = self._subsetTopology
#PDBFile will convert to angstroms so do it here first instead #PDBFile will convert to angstroms so do it here first instead
positions = state.getPositions().value_in_unit(angstroms) positions = state.getPositions(asNumpy=True).value_in_unit(angstroms)
positions = [positions[i] for i in self._atomSubset] positions = [positions[i] for i in self._atomSubset]
else: else:
topology = simulation.topology topology = simulation.topology
positions = state.getPositions() positions = state.getPositions(asNumpy=True)
if self._nextModel == 0: if self._nextModel == 0:
PDBxFile.writeHeader(topology, self._out) PDBxFile.writeHeader(topology, self._out)
......
...@@ -418,9 +418,10 @@ class PDBxFile(object): ...@@ -418,9 +418,10 @@ class PDBxFile(object):
raise ValueError('The number of positions must match the number of atoms') raise ValueError('The number of positions must match the number of atoms')
if is_quantity(positions): if is_quantity(positions):
positions = positions.value_in_unit(angstroms) positions = positions.value_in_unit(angstroms)
if any(math.isnan(norm(pos)) for pos in positions): import numpy as np
if np.isnan(positions).any():
raise ValueError('Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan') raise ValueError('Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
if any(math.isinf(norm(pos)) for pos in positions): if np.isinf(positions).any():
raise ValueError('Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan') raise ValueError('Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
nonHeterogens = PDBFile._standardResidues[:] nonHeterogens = PDBFile._standardResidues[:]
nonHeterogens.remove('HOH') nonHeterogens.remove('HOH')
......
...@@ -9,11 +9,9 @@ from openmm.app.internal.xtc_utils import ( ...@@ -9,11 +9,9 @@ from openmm.app.internal.xtc_utils import (
get_xtc_nframes, get_xtc_nframes,
get_xtc_natoms, get_xtc_natoms,
) )
import numpy as np
import os import os
from openmm import Vec3 from openmm import Vec3
from openmm.unit import nanometers, picoseconds, is_quantity, norm from openmm.unit import nanometers, picoseconds, is_quantity, norm
import math
import tempfile import tempfile
import shutil import shutil
...@@ -92,11 +90,12 @@ class XTCFile(object): ...@@ -92,11 +90,12 @@ class XTCFile(object):
raise ValueError("The number of positions must match the number of atoms") raise ValueError("The number of positions must match the number of atoms")
if is_quantity(positions): if is_quantity(positions):
positions = positions.value_in_unit(nanometers) positions = positions.value_in_unit(nanometers)
if any(math.isnan(norm(pos)) for pos in positions): import numpy as np
if np.isnan(positions).any():
raise ValueError( raise ValueError(
"Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan" "Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan"
) )
if any(math.isinf(norm(pos)) for pos in positions): if np.isinf(positions).any():
raise ValueError( raise ValueError(
"Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan" "Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan"
) )
......
...@@ -71,5 +71,5 @@ class XTCReporter(object): ...@@ -71,5 +71,5 @@ class XTCReporter(object):
self._append, self._append,
) )
self._xtc.writeModel( self._xtc.writeModel(
state.getPositions(), periodicBoxVectors=state.getPeriodicBoxVectors() state.getPositions(asNumpy=True), periodicBoxVectors=state.getPeriodicBoxVectors()
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment