Unverified Commit f6e6b6ed authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Optimizations to reporters (#4330)

* Optimizations to reporters

* Removed unneeded imports
parent 127a3733
......@@ -121,9 +121,10 @@ class DCDFile(object):
raise ValueError('The number of positions must match the number of atoms')
if is_quantity(positions):
positions = positions.value_in_unit(nanometers)
if any(math.isnan(norm(pos)) for pos in positions):
import numpy as np
if np.isnan(positions).any():
raise ValueError('Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
if any(math.isinf(norm(pos)) for pos in positions):
if np.isinf(positions).any():
raise ValueError('Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
file = self._file
......
......@@ -104,7 +104,7 @@ class DCDReporter(object):
self._out, simulation.topology, simulation.integrator.getStepSize(),
simulation.currentStep, self._reportInterval, self._append
)
self._dcd.writeModel(state.getPositions(), periodicBoxVectors=state.getPeriodicBoxVectors())
self._dcd.writeModel(state.getPositions(asNumpy=True), periodicBoxVectors=state.getPeriodicBoxVectors())
def __del__(self):
self._out.close()
......@@ -343,9 +343,10 @@ class PDBFile(object):
raise ValueError('The number of positions must match the number of atoms')
if is_quantity(positions):
positions = positions.value_in_unit(angstroms)
if any(math.isnan(norm(pos)) for pos in positions):
import numpy as np
if np.isnan(positions).any():
raise ValueError('Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
if any(math.isinf(norm(pos)) for pos in positions):
if np.isinf(positions).any():
raise ValueError('Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
nonHeterogens = PDBFile._standardResidues[:]
nonHeterogens.remove('HOH')
......
......@@ -104,12 +104,12 @@ class PDBReporter(object):
topology = self._subsetTopology
#PDBFile will convert to angstroms so do it here first instead
positions = state.getPositions().value_in_unit(angstroms)
positions = state.getPositions(asNumpy=True).value_in_unit(angstroms)
positions = [positions[i] for i in self._atomSubset]
else:
topology = simulation.topology
positions = state.getPositions()
positions = state.getPositions(asNumpy=True)
if self._nextModel == 0:
PDBFile.writeHeader(topology, self._out)
......@@ -202,12 +202,12 @@ class PDBxReporter(PDBReporter):
topology = self._subsetTopology
#PDBFile will convert to angstroms so do it here first instead
positions = state.getPositions().value_in_unit(angstroms)
positions = state.getPositions(asNumpy=True).value_in_unit(angstroms)
positions = [positions[i] for i in self._atomSubset]
else:
topology = simulation.topology
positions = state.getPositions()
positions = state.getPositions(asNumpy=True)
if self._nextModel == 0:
PDBxFile.writeHeader(topology, self._out)
......
......@@ -418,9 +418,10 @@ class PDBxFile(object):
raise ValueError('The number of positions must match the number of atoms')
if is_quantity(positions):
positions = positions.value_in_unit(angstroms)
if any(math.isnan(norm(pos)) for pos in positions):
import numpy as np
if np.isnan(positions).any():
raise ValueError('Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
if any(math.isinf(norm(pos)) for pos in positions):
if np.isinf(positions).any():
raise ValueError('Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan')
nonHeterogens = PDBFile._standardResidues[:]
nonHeterogens.remove('HOH')
......
......@@ -9,11 +9,9 @@ from openmm.app.internal.xtc_utils import (
get_xtc_nframes,
get_xtc_natoms,
)
import numpy as np
import os
from openmm import Vec3
from openmm.unit import nanometers, picoseconds, is_quantity, norm
import math
import tempfile
import shutil
......@@ -92,11 +90,12 @@ class XTCFile(object):
raise ValueError("The number of positions must match the number of atoms")
if is_quantity(positions):
positions = positions.value_in_unit(nanometers)
if any(math.isnan(norm(pos)) for pos in positions):
import numpy as np
if np.isnan(positions).any():
raise ValueError(
"Particle position is NaN. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan"
)
if any(math.isinf(norm(pos)) for pos in positions):
if np.isinf(positions).any():
raise ValueError(
"Particle position is infinite. For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan"
)
......
......@@ -71,5 +71,5 @@ class XTCReporter(object):
self._append,
)
self._xtc.writeModel(
state.getPositions(), periodicBoxVectors=state.getPeriodicBoxVectors()
state.getPositions(asNumpy=True), periodicBoxVectors=state.getPeriodicBoxVectors()
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment