Unverified Commit 2ff294c6 authored by Stefan Doerr's avatar Stefan Doerr Committed by GitHub
Browse files

Fixing XTC/DCD time and step writing (#4879)

* add tests for correctness of step and time written in XTC and DCD

* improve tests

* improve xtc tests

* fix XTC/DCD time/step writing

* different approach by changing the reporters to not pass currentStep as firstStep but instead interval

* undo change
parent adf1843f
...@@ -145,8 +145,7 @@ class DCDFile(object): ...@@ -145,8 +145,7 @@ class DCDFile(object):
file.seek(8, os.SEEK_SET) file.seek(8, os.SEEK_SET)
file.write(struct.pack('<i', self._modelCount)) file.write(struct.pack('<i', self._modelCount))
file.seek(20, os.SEEK_SET) file.seek(20, os.SEEK_SET)
file.write(struct.pack('<i', self._firstStep+self._modelCount*self._interval)) file.write(struct.pack('<i', self._firstStep+(self._modelCount-1)*self._interval))
# Write the data. # Write the data.
file.seek(0, os.SEEK_END) file.seek(0, os.SEEK_END)
......
...@@ -110,7 +110,7 @@ class DCDReporter(object): ...@@ -110,7 +110,7 @@ class DCDReporter(object):
topology.addAtom(atoms[i].name, atoms[i].element, residue) topology.addAtom(atoms[i].name, atoms[i].element, residue)
self._dcd = DCDFile( self._dcd = DCDFile(
self._out, topology, simulation.integrator.getStepSize(), self._out, topology, simulation.integrator.getStepSize(),
simulation.currentStep, self._reportInterval, self._append self._reportInterval, self._reportInterval, self._append
) )
positions = state.getPositions(asNumpy=True) positions = state.getPositions(asNumpy=True)
if self._atomSubset is not None: if self._atomSubset is not None:
......
...@@ -107,7 +107,7 @@ class XTCReporter(object): ...@@ -107,7 +107,7 @@ class XTCReporter(object):
self._fileName, self._fileName,
topology, topology,
simulation.integrator.getStepSize(), simulation.integrator.getStepSize(),
simulation.currentStep, self._reportInterval,
self._reportInterval, self._reportInterval,
self._append, self._append,
) )
......
...@@ -6,6 +6,17 @@ from openmm import unit ...@@ -6,6 +6,17 @@ from openmm import unit
from random import random from random import random
import os import os
def _read_dcd_header(file):
import struct
with open(file, "r+b") as f:
f.seek(8, os.SEEK_SET)
modelCount = struct.unpack("<i", f.read(4))[0]
f.seek(20, os.SEEK_SET)
currStep = struct.unpack("<i", f.read(4))[0]
return modelCount, currStep
class TestDCDFile(unittest.TestCase): class TestDCDFile(unittest.TestCase):
def test_dcd(self): def test_dcd(self):
""" Test the DCD file """ """ Test the DCD file """
...@@ -49,11 +60,15 @@ class TestDCDFile(unittest.TestCase): ...@@ -49,11 +60,15 @@ class TestDCDFile(unittest.TestCase):
del simulation del simulation
del dcd del dcd
len1 = os.stat(fname).st_size len1 = os.stat(fname).st_size
modelCount, currStep = _read_dcd_header(fname)
self.assertEqual(5, modelCount)
self.assertEqual(10, currStep)
# Create a new simulation and have it append some more frames. # Create a new simulation and have it append some more frames.
integrator = mm.VerletIntegrator(0.001*unit.picoseconds) integrator = mm.VerletIntegrator(0.001*unit.picoseconds)
simulation = app.Simulation(pdb.topology, system, integrator, mm.Platform.getPlatform('Reference')) simulation = app.Simulation(pdb.topology, system, integrator, mm.Platform.getPlatform('Reference'))
simulation.currentStep = 10
dcd = app.DCDReporter(fname, 2, append=True) dcd = app.DCDReporter(fname, 2, append=True)
simulation.reporters.append(dcd) simulation.reporters.append(dcd)
simulation.context.setPositions(pdb.positions) simulation.context.setPositions(pdb.positions)
...@@ -64,6 +79,9 @@ class TestDCDFile(unittest.TestCase): ...@@ -64,6 +79,9 @@ class TestDCDFile(unittest.TestCase):
self.assertTrue(len2-len1 > 3*4*5*system.getNumParticles()) self.assertTrue(len2-len1 > 3*4*5*system.getNumParticles())
del simulation del simulation
del dcd del dcd
modelCount, currStep = _read_dcd_header(fname)
self.assertEqual(10, modelCount)
self.assertEqual(20, currStep)
os.remove(fname) os.remove(fname)
def testAtomSubset(self): def testAtomSubset(self):
...@@ -87,11 +105,15 @@ class TestDCDFile(unittest.TestCase): ...@@ -87,11 +105,15 @@ class TestDCDFile(unittest.TestCase):
del simulation del simulation
del dcd del dcd
len1 = os.stat(fname).st_size len1 = os.stat(fname).st_size
modelCount, currStep = _read_dcd_header(fname)
self.assertEqual(5, modelCount)
self.assertEqual(10, currStep)
# Create a new simulation and have it append some more frames. # Create a new simulation and have it append some more frames.
integrator = mm.VerletIntegrator(0.001*unit.picoseconds) integrator = mm.VerletIntegrator(0.001*unit.picoseconds)
simulation = app.Simulation(pdb.topology, system, integrator, mm.Platform.getPlatform('Reference')) simulation = app.Simulation(pdb.topology, system, integrator, mm.Platform.getPlatform('Reference'))
simulation.currentStep = 10
dcd = app.DCDReporter(fname, 2, append=True, atomSubset=atomSubset) dcd = app.DCDReporter(fname, 2, append=True, atomSubset=atomSubset)
simulation.reporters.append(dcd) simulation.reporters.append(dcd)
simulation.context.setPositions(pdb.positions) simulation.context.setPositions(pdb.positions)
...@@ -102,6 +124,9 @@ class TestDCDFile(unittest.TestCase): ...@@ -102,6 +124,9 @@ class TestDCDFile(unittest.TestCase):
self.assertTrue(len2-len1 > 3*4*5*len(atomSubset)) self.assertTrue(len2-len1 > 3*4*5*len(atomSubset))
del simulation del simulation
del dcd del dcd
modelCount, currStep = _read_dcd_header(fname)
self.assertEqual(10, modelCount)
self.assertEqual(20, currStep)
os.remove(fname) os.remove(fname)
......
...@@ -191,6 +191,8 @@ class TestXtcFile(unittest.TestCase): ...@@ -191,6 +191,8 @@ class TestXtcFile(unittest.TestCase):
) )
def testAppend(self): def testAppend(self):
from openmm.app.internal.xtc_utils import read_xtc
"""Test appending to an existing trajectory.""" """Test appending to an existing trajectory."""
with tempfile.TemporaryDirectory() as temp: with tempfile.TemporaryDirectory() as temp:
fname = os.path.join(temp, 'traj.xtc') fname = os.path.join(temp, 'traj.xtc')
...@@ -214,6 +216,9 @@ class TestXtcFile(unittest.TestCase): ...@@ -214,6 +216,9 @@ class TestXtcFile(unittest.TestCase):
simulation.step(10) simulation.step(10)
self.assertEqual(5, xtc._xtc._modelCount) self.assertEqual(5, xtc._xtc._modelCount)
self.assertEqual(5, xtc._xtc._getNumFrames()) self.assertEqual(5, xtc._xtc._getNumFrames())
_, _, time, step = read_xtc(fname.encode("utf-8"))
self.assertTrue(np.allclose(np.arange(2, 11, 2) * 0.001, time))
self.assertTrue(np.array_equal(np.arange(2, 11, 2), step))
del simulation del simulation
del xtc del xtc
...@@ -226,6 +231,7 @@ class TestXtcFile(unittest.TestCase): ...@@ -226,6 +231,7 @@ class TestXtcFile(unittest.TestCase):
integrator, integrator,
mm.Platform.getPlatform("Reference"), mm.Platform.getPlatform("Reference"),
) )
simulation.currentStep = 10
xtc = app.XTCReporter(fname, 2, append=True) xtc = app.XTCReporter(fname, 2, append=True)
simulation.reporters.append(xtc) simulation.reporters.append(xtc)
simulation.context.setPositions(pdb.positions) simulation.context.setPositions(pdb.positions)
...@@ -233,6 +239,9 @@ class TestXtcFile(unittest.TestCase): ...@@ -233,6 +239,9 @@ class TestXtcFile(unittest.TestCase):
simulation.step(10) simulation.step(10)
self.assertEqual(10, xtc._xtc._modelCount) self.assertEqual(10, xtc._xtc._modelCount)
self.assertEqual(10, xtc._xtc._getNumFrames()) self.assertEqual(10, xtc._xtc._getNumFrames())
_, _, time, step = read_xtc(fname.encode("utf-8"))
self.assertTrue(np.allclose(np.arange(2, 21, 2) * 0.001, time))
self.assertTrue(np.array_equal(np.arange(2, 21, 2), step))
del simulation del simulation
del xtc del xtc
...@@ -269,6 +278,8 @@ class TestXtcFile(unittest.TestCase): ...@@ -269,6 +278,8 @@ class TestXtcFile(unittest.TestCase):
self.assertEqual(box_read.shape, (3, 3, 5)) self.assertEqual(box_read.shape, (3, 3, 5))
self.assertEqual(len(time), 5) self.assertEqual(len(time), 5)
self.assertEqual(len(step), 5) self.assertEqual(len(step), 5)
self.assertTrue(np.allclose(np.arange(2, 11, 2) * 1e-10, time))
self.assertTrue(np.array_equal(np.arange(2, 11, 2), step))
coords = [pdb.positions[i].value_in_unit(unit.nanometers) for i in atomSubset] coords = [pdb.positions[i].value_in_unit(unit.nanometers) for i in atomSubset]
self.assertTrue(np.allclose(coords_read[:,:,0], coords, atol=1e-3)) self.assertTrue(np.allclose(coords_read[:,:,0], coords, atol=1e-3))
self.assertTrue(np.allclose(box_read[:,:,0], pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometers), atol=1e-3)) self.assertTrue(np.allclose(box_read[:,:,0], pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometers), atol=1e-3))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment