Commit b7088b74 authored by peastman's avatar peastman Committed by Robert McGibbon
Browse files

Python 2/3 compatibility in single code base, plus python 3 testing on travis.

parent 4c00b312
......@@ -12,6 +12,7 @@
Test implementing PDBx/mmCIF write and formatting operations.
"""
from __future__ import absolute_import
__docformat__ = "restructuredtext en"
__author__ = "John Westbrook"
__email__ = "jwest@rcsb.rutgers.edu"
......
......@@ -28,6 +28,7 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
from __future__ import absolute_import
__author__ = "Peter Eastman"
__version__ = "1.0"
......
......@@ -29,6 +29,7 @@ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
from __future__ import division
from __future__ import absolute_import
__author__ = "Peter Eastman"
__version__ = "1.0"
......@@ -40,7 +41,7 @@ from simtk.openmm.vec3 import Vec3
from simtk.openmm import System, Context, NonbondedForce, CustomNonbondedForce, HarmonicBondForce, HarmonicAngleForce, VerletIntegrator, LocalEnergyMinimizer
from simtk.unit import nanometer, molar, elementary_charge, amu, gram, liter, degree, sqrt, acos, is_quantity, dot, norm
import simtk.unit as unit
import element as elem
from . import element as elem
import os
import random
import xml.etree.ElementTree as etree
......@@ -877,7 +878,7 @@ class Modeller(object):
# Create copies of all residue templates that have had all extra points removed.
templatesNoEP = {}
for resName, template in forcefield._templates.iteritems():
for resName, template in forcefield._templates.items():
if any(atom.element is None for atom in template.atoms):
index = 0
newIndex = {}
......
......@@ -28,6 +28,7 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
from __future__ import absolute_import
__author__ = "Peter Eastman"
__version__ = "1.0"
......
......@@ -28,6 +28,7 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
from __future__ import absolute_import
__author__ = "Peter Eastman"
__version__ = "1.0"
......@@ -133,7 +134,7 @@ class Simulation(object):
def _simulate(self, endStep=None, endTime=None):
if endStep is None:
endStep = sys.maxint
endStep = sys.maxsize
nextReport = [None]*len(self.reporters)
while self.currentStep < endStep and (endTime is None or datetime.now() < endTime):
nextSteps = endStep-self.currentStep
......
......@@ -28,6 +28,8 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
from __future__ import absolute_import
from __future__ import print_function
__author__ = "Peter Eastman"
__version__ = "1.0"
......@@ -146,7 +148,7 @@ class StateDataReporter(object):
if not self._hasInitialized:
self._initializeConstants(simulation)
headers = self._constructHeaders()
print >>self._out, '#"%s"' % ('"'+self._separator+'"').join(headers)
print('#"%s"' % ('"'+self._separator+'"').join(headers), file=self._out)
try:
self._out.flush()
except AttributeError:
......@@ -163,7 +165,7 @@ class StateDataReporter(object):
values = self._constructReportValues(simulation, state)
# Write the values.
print >>self._out, self._separator.join(str(v) for v in values)
print(self._separator.join(str(v) for v in values), file=self._out)
try:
self._out.flush()
except AttributeError:
......
......@@ -28,6 +28,7 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
from __future__ import absolute_import
__author__ = "Peter Eastman"
__version__ = "1.0"
......
......@@ -28,6 +28,7 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
from __future__ import absolute_import
__author__ = "Peter Eastman"
__version__ = "1.0"
......
......@@ -28,6 +28,7 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
from __future__ import absolute_import
__author__ = "Peter Eastman"
__version__ = "1.0"
......
from __future__ import print_function
from __future__ import absolute_import
from functools import wraps
import os
import sys
......
"""
Physical quantities with units for dimensional analysis and automatic unit conversion.
"""
__docformat__ = "epytext en"
__author__ = "Christopher M. Bruns"
__copyright__ = "Copyright 2010, Stanford University and Christopher M. Bruns"
__credits__ = []
__license__ = "MIT"
__maintainer__ = "Christopher M. Bruns"
__email__ = "cmbruns@stanford.edu"
from unit import Unit, is_unit
from quantity import Quantity, is_quantity
from unit_math import *
from unit_definitions import *
from constants import *
"""
Physical quantities with units for dimensional analysis and automatic unit conversion.
"""
from __future__ import absolute_import
__docformat__ = "epytext en"
__author__ = "Christopher M. Bruns"
__copyright__ = "Copyright 2010, Stanford University and Christopher M. Bruns"
__credits__ = []
__license__ = "MIT"
__maintainer__ = "Christopher M. Bruns"
__email__ = "cmbruns@stanford.edu"
from .unit import Unit, is_unit
from .quantity import Quantity, is_quantity
from .unit_math import *
from .unit_definitions import *
from .constants import *
......@@ -807,8 +807,8 @@ def _is_string(x):
if isinstance(x, str):
return True
try:
first_item = iter(x).next()
inner_item = iter(first_item).next()
first_item = next(iter(x))
inner_item = next(iter(first_item))
if first_item is inner_item:
return True
else:
......
......@@ -102,10 +102,10 @@ class Unit(object):
# TODO - also handle non-simple units, i.e. units with multiple BaseUnits/ScaledUnits
assert len(self._top_base_units) == 1
assert len(self._scaled_units) == 0
dimension = self._top_base_units.iterkeys().next()
dimension = next(iter(self._top_base_units))
base_unit_dict = self._top_base_units[dimension]
assert len(base_unit_dict) == 1
parent_base_unit = base_unit_dict.iterkeys().next()
parent_base_unit = next(iter(base_unit_dict))
parent_exponent = base_unit_dict[parent_base_unit]
new_base_unit = BaseUnit(parent_base_unit.dimension, name, symbol)
# BaseUnit scale might be different depending on exponent
......@@ -121,10 +121,8 @@ class Unit(object):
Yields (BaseDimension, exponent) tuples comprising this unit.
"""
# There might be two units with the same dimension? No.
for dimension in sorted(self._all_base_units.iterkeys()):
exponent = 0
for base_unit in sorted(self._all_base_units[dimension].iterkeys()):
exponent += self._all_base_units[dimension][base_unit]
for dimension in sorted(self._all_base_units.keys()):
exponent = sum(self._all_base_units[dimension].values())
if exponent != 0:
yield (dimension, exponent)
......@@ -135,8 +133,8 @@ class Unit(object):
There might be multiple BaseUnits with the same dimension.
"""
for dimension in sorted(self._all_base_units.iterkeys()):
for base_unit in sorted(self._all_base_units[dimension].iterkeys()):
for dimension in sorted(self._all_base_units.keys()):
for base_unit in sorted(self._all_base_units[dimension].keys()):
exponent = self._all_base_units[dimension][base_unit]
yield (base_unit, exponent)
......@@ -144,8 +142,8 @@ class Unit(object):
"""
Yields (BaseUnit, exponent) tuples in this Unit, excluding those within BaseUnits.
"""
for dimension in sorted(self._top_base_units.iterkeys()):
for unit in sorted(self._top_base_units[dimension].iterkeys()):
for dimension in sorted(self._top_base_units.keys()):
for unit in sorted(self._top_base_units[dimension].keys()):
exponent = self._top_base_units[dimension][unit]
yield (unit, exponent)
......@@ -518,7 +516,7 @@ class ScaledUnit(object):
self.symbol = symbol
def __iter__(self):
for dim in sorted(self.base_units.iterkeys()):
for dim in sorted(self.base_units.keys()):
yield self.base_units[dim]
def iter_base_units(self):
......@@ -602,8 +600,7 @@ class UnitSystem(object):
if not len(self.base_units) == len(self.units):
raise ArithmeticError("UnitSystem must have same number of units as base dimensions")
# self.dimensions is a dict of {BaseDimension: index}
dimensions = base_units.keys()
dimensions.sort()
dimensions = sorted(base_units.keys())
self.dimensions = {}
for d in range(len(dimensions)):
self.dimensions[dimensions[d]] = d
......
......@@ -278,6 +278,7 @@ Parameters:
"""Get the list of Forces in this System"""
return [self.getForce(i) for i in range(self.getNumForces())]
%}
%newobject __copy__;
OpenMM::System* __copy__() {
return OpenMM::XmlSerializer::clone<OpenMM::System>(*self);
}
......@@ -452,6 +453,7 @@ Parameters:
def __deepcopy__(self, memo):
return self.__copy__()
%}
%newobject __copy__;
OpenMM::Force* __copy__() {
return OpenMM::XmlSerializer::clone<OpenMM::Force>(*self);
}
......@@ -470,6 +472,7 @@ Parameters:
def __deepcopy__(self, memo):
return self.__copy__()
%}
%newobject __copy__;
OpenMM::Integrator* __copy__() {
return OpenMM::XmlSerializer::clone<OpenMM::Integrator>(*self);
}
......
......@@ -337,3 +337,32 @@ int Py_SequenceToVecDouble(PyObject* obj, std::vector<double>& out) {
// createCheckpoint returns a bytes object
$result = PyBytes_FromStringAndSize($1.c_str(), $1.length());
}
%typemap(in) std::string {
// if we have a C++ method that takes in a std::string, we're most happy
// to accept a python bytes object. But if the user passes in a unicode
// object we'll try to recover by encoding it to UTF-8 bytes
PyObject* temp = NULL;
char* c_str = NULL;
Py_ssize_t len = 0;
if (PyUnicode_Check($input)) {
temp = PyUnicode_AsUTF8String($input);
if (temp == NULL) {
SWIG_exception_fail(SWIG_TypeError, "'utf-8' codec can't decode byte");
}
PyBytes_AsStringAndSize(temp, &c_str, &len);
Py_XDECREF(temp);
} else if (PyBytes_Check($input)) {
PyBytes_AsStringAndSize($input, &c_str, &len);
} else {
SWIG_exception_fail(SWIG_TypeError, "argument must be str or bytes");
}
if (c_str == NULL) {
SWIG_exception_fail(SWIG_TypeError, "argument must be str or bytes");
}
$1 = std::string(c_str, len);
}
......@@ -4,14 +4,21 @@ from simtk.openmm.app import *
from simtk.openmm import *
from simtk.unit import *
import simtk.openmm.app.element as elem
try:
from scipy.io import netcdf
SCIPY_IMPORT_FAILED = False
except:
SCIPY_IMPORT_FAILED = True
def compareByElement(array1, array2, cmp):
for x, y in zip(array1, array2):
cmp(x, y)
class TestAmberInpcrdFile(unittest.TestCase):
"""Test the Amber inpcrd file parser"""
def test_CrdVelBox(self):
""" Test parsing ASCII restarts with crds, vels, and box """
cmp = self.assertAlmostEqual
......@@ -24,24 +31,21 @@ class TestAmberInpcrdFile(unittest.TestCase):
compareByElement(inpcrd.boxVectors[0].value_in_unit(angstroms),
[30.2642725, 0.0, 0.0], cmp)
@unittest.skipIf(SCIPY_IMPORT_FAILED, "Scipy is not installed")
def test_NetCDF(self):
""" Test NetCDF restart file parsing """
cmp = self.assertAlmostEqual
try:
from scipy.io import netcdf
except ImportError:
print('Not testing NetCDF file parser... scipy cannot be found')
else:
inpcrd = AmberInpcrdFile('systems/amber.ncrst')
self.assertEqual(len(inpcrd.positions), 2101)
compareByElement(inpcrd.positions[0].value_in_unit(angstroms),
[6.82122492718229, 6.6276250662042, -8.51668999892245],
cmp)
compareByElement(inpcrd.velocities[-1].value_in_unit(angstroms/picosecond),
[0.349702202733541*20.455, 0.391525333168534*20.455,
0.417941679767662*20.455], cmp)
self.assertAlmostEqual(inpcrd.boxVectors[0][0].value_in_unit(angstroms),
30.2642725, places=6)
inpcrd = AmberInpcrdFile('systems/amber.ncrst')
self.assertEqual(len(inpcrd.positions), 2101)
compareByElement(inpcrd.positions[0].value_in_unit(angstroms),
[6.82122492718229, 6.6276250662042, -8.51668999892245],
cmp)
compareByElement(inpcrd.velocities[-1].value_in_unit(angstroms/picosecond),
[0.349702202733541*20.455, 0.391525333168534*20.455,
0.417941679767662*20.455], cmp)
self.assertAlmostEqual(inpcrd.boxVectors[0][0].value_in_unit(angstroms),
30.2642725, places=6)
def test_CrdBox(self):
""" Test parsing ASCII restarts with only crds and box """
......
......@@ -23,23 +23,23 @@ class TestAmberPrmtopFile(unittest.TestCase):
def test_NonbondedMethod(self):
"""Test all five options for the nonbondedMethod parameter."""
methodMap = {NoCutoff:NonbondedForce.NoCutoff,
CutoffNonPeriodic:NonbondedForce.CutoffNonPeriodic,
CutoffPeriodic:NonbondedForce.CutoffPeriodic,
methodMap = {NoCutoff:NonbondedForce.NoCutoff,
CutoffNonPeriodic:NonbondedForce.CutoffNonPeriodic,
CutoffPeriodic:NonbondedForce.CutoffPeriodic,
Ewald:NonbondedForce.Ewald, PME: NonbondedForce.PME}
for method in methodMap:
system = prmtop1.createSystem(nonbondedMethod=method)
forces = system.getForces()
self.assertTrue(any(isinstance(f, NonbondedForce) and
f.getNonbondedMethod()==methodMap[method]
self.assertTrue(any(isinstance(f, NonbondedForce) and
f.getNonbondedMethod()==methodMap[method]
for f in forces))
def test_Cutoff(self):
"""Test to make sure the nonbondedCutoff parameter is passed correctly."""
for method in [CutoffNonPeriodic, CutoffPeriodic, Ewald, PME]:
system = prmtop1.createSystem(nonbondedMethod=method,
nonbondedCutoff=2*nanometer,
system = prmtop1.createSystem(nonbondedMethod=method,
nonbondedCutoff=2*nanometer,
constraints=HBonds)
cutoff_distance = 0.0*nanometer
cutoff_check = 2.0*nanometer
......@@ -52,8 +52,8 @@ class TestAmberPrmtopFile(unittest.TestCase):
"""Test to make sure the ewaldErrorTolerance parameter is passed correctly."""
for method in [Ewald, PME]:
system = prmtop1.createSystem(nonbondedMethod=method,
ewaldErrorTolerance=1e-6,
system = prmtop1.createSystem(nonbondedMethod=method,
ewaldErrorTolerance=1e-6,
constraints=HBonds)
tolerance = 0
tolerance_check = 1e-6
......@@ -76,13 +76,13 @@ class TestAmberPrmtopFile(unittest.TestCase):
topology = prmtop1.topology
for constraints_value in [None, HBonds, AllBonds, HAngles]:
for rigidWater_value in [True, False]:
system = prmtop1.createSystem(constraints=constraints_value,
system = prmtop1.createSystem(constraints=constraints_value,
rigidWater=rigidWater_value)
validateConstraints(self, topology, system,
validateConstraints(self, topology, system,
constraints_value, rigidWater_value)
def test_ImplicitSolvent(self):
"""Test the four types of implicit solvents using the implicitSolvent
"""Test the four types of implicit solvents using the implicitSolvent
parameter.
"""
......@@ -93,7 +93,7 @@ class TestAmberPrmtopFile(unittest.TestCase):
force_type = CustomGBForce
else:
force_type = GBSAOBCForce
self.assertTrue(any(isinstance(f, force_type) for f in forces))
def test_ImplicitSolventParameters(self):
......@@ -102,7 +102,7 @@ class TestAmberPrmtopFile(unittest.TestCase):
CutoffNonPeriodic:NonbondedForce.CutoffNonPeriodic}
for implicitSolvent_value in [HCT, OBC1, OBC2, GBn]:
for method in methodMap:
system = prmtop2.createSystem(implicitSolvent=implicitSolvent_value,
system = prmtop2.createSystem(implicitSolvent=implicitSolvent_value,
solventDielectric=50.0, soluteDielectric=0.9, nonbondedMethod=method)
if implicitSolvent_value in set([HCT, OBC1, GBn]):
for force in system.getForces():
......@@ -122,12 +122,12 @@ class TestAmberPrmtopFile(unittest.TestCase):
if isinstance(force, NonbondedForce):
self.assertEqual(force.getReactionFieldDielectric(), 1.0)
self.assertEqual(force.getNonbondedMethod(), methodMap[method])
self.assertTrue(found_matching_solvent_dielectric and
self.assertTrue(found_matching_solvent_dielectric and
found_matching_solute_dielectric)
def test_HydrogenMass(self):
"""Test that altering the mass of hydrogens works correctly."""
topology = prmtop1.topology
hydrogenMass = 4*amu
system1 = prmtop1.createSystem()
......@@ -279,7 +279,7 @@ class TestAmberPrmtopFile(unittest.TestCase):
def test_ImplicitSolventForces(self):
"""Compute forces for different implicit solvent types, and compare them to ones generated with a previous version of OpenMM to ensure they haven't changed."""
solventType = [HCT, OBC1, OBC2, GBn, GBn2]
nonbondedMethod = [NoCutoff, CutoffNonPeriodic, CutoffNonPeriodic, NoCutoff, NoCutoff]
salt = [0.0, 0.0, 0.5, 0.5, 0.0]*(moles/liter)
......@@ -288,7 +288,7 @@ class TestAmberPrmtopFile(unittest.TestCase):
for i in range(5):
system = prmtop2.createSystem(implicitSolvent=solventType[i], nonbondedMethod=nonbondedMethod[i], implicitSolventSaltConc=salt[i])
integrator = VerletIntegrator(0.001)
context = Context(system, integrator, Platform.getPlatformByName("CPU"))
context = Context(system, integrator, Platform.getPlatformByName("Reference"))
context.setPositions(pdb.positions)
state1 = context.getState(getForces=True)
state2 = XmlSerializer.deserialize(open('systems/alanine-dipeptide-implicit-forces/'+file[i]+'.xml').read())
......@@ -307,7 +307,7 @@ class TestAmberPrmtopFile(unittest.TestCase):
system.addForce(MonteCarloBarostat(1.0 * atmospheres, temperature, 1))
integrator = LangevinIntegrator(temperature, 1.0 / picosecond, 0.0001 * picoseconds)
simulation = Simulation(prmtop.topology, system, integrator)
simulation.context.setPositions(inpcrd.positions)
simulation.context.setPeriodicBoxVectors(*inpcrd.boxVectors)
......@@ -315,6 +315,7 @@ class TestAmberPrmtopFile(unittest.TestCase):
fname = tempfile.mktemp(suffix='.dcd')
simulation.reporters.append(DCDReporter(fname, 1)) # This is an explicit test for the bugs in issue #850
simulation.step(5)
del simulation
os.remove(fname)
def testChamber(self):
......
......@@ -8,7 +8,7 @@ import simtk.openmm.app.element as elem
class TestCharmmFiles(unittest.TestCase):
"""Test the GromacsTopFile.createSystem() method."""
def setUp(self):
"""Set up the tests by loading the input files."""
......@@ -23,14 +23,14 @@ class TestCharmmFiles(unittest.TestCase):
def test_NonbondedMethod(self):
"""Test both non-periodic methods for the systems"""
methodMap = {NoCutoff:NonbondedForce.NoCutoff,
methodMap = {NoCutoff:NonbondedForce.NoCutoff,
CutoffNonPeriodic:NonbondedForce.CutoffNonPeriodic}
for top in (self.psf_c, self.psf_x, self.psf_v):
for method in methodMap:
system = top.createSystem(self.params, nonbondedMethod=method)
forces = system.getForces()
self.assertTrue(any(isinstance(f, NonbondedForce) and
f.getNonbondedMethod()==methodMap[method]
self.assertTrue(any(isinstance(f, NonbondedForce) and
f.getNonbondedMethod()==methodMap[method]
for f in forces))
def test_Cutoff(self):
......@@ -39,7 +39,7 @@ class TestCharmmFiles(unittest.TestCase):
for top in (self.psf_c, self.psf_x, self.psf_v):
for method in [CutoffNonPeriodic]:
system = top.createSystem(self.params, nonbondedMethod=method,
nonbondedCutoff=2*nanometer,
nonbondedCutoff=2*nanometer,
constraints=HBonds)
cutoff_distance = 0.0*nanometer
cutoff_check = 2.0*nanometer
......@@ -67,7 +67,7 @@ class TestCharmmFiles(unittest.TestCase):
"""
system = self.psf_x.createSystem(self.params, implicitSolvent=GBn,
solventDielectric=50.0,
solventDielectric=50.0,
soluteDielectric = 0.9)
for force in system.getForces():
if isinstance(force, NonbondedForce):
......@@ -75,7 +75,7 @@ class TestCharmmFiles(unittest.TestCase):
def test_HydrogenMass(self):
"""Test that altering the mass of hydrogens works correctly."""
topology = self.psf_v.topology
hydrogenMass = 4*amu
system1 = self.psf_v.createSystem(self.params)
......@@ -131,7 +131,7 @@ class TestCharmmFiles(unittest.TestCase):
for i in range(5):
system = self.psf_c.createSystem(self.params, implicitSolvent=solventType[i], nonbondedMethod=nonbondedMethod[i], implicitSolventSaltConc=salt[i])
integrator = VerletIntegrator(0.001)
context = Context(system, integrator, Platform.getPlatformByName("CPU"))
context = Context(system, integrator, Platform.getPlatformByName("Reference"))
context.setPositions(self.pdb.positions)
state1 = context.getState(getForces=True)
#out = open('systems/ala-ala-ala-implicit-forces/'+file[i]+'.xml', 'w')
......
import os
import unittest
import tempfile
import numpy as np
from simtk.openmm import app
import simtk.openmm as mm
from simtk import unit
......@@ -18,21 +18,24 @@ class TestCheckpointReporter(unittest.TestCase):
self.simulation.context.setPositions(pdb.positions)
def test_1(self):
file = tempfile.NamedTemporaryFile()
file = tempfile.NamedTemporaryFile(delete=False)
self.simulation.reporters.append(app.CheckpointReporter(file, 1))
self.simulation.step(1)
# get the current positions
positions = self.simulation.context.getState(getPositions=True).getPositions(asNumpy=True)._value
positions = self.simulation.context.getState(getPositions=True).getPositions()
# now set the positions into junk...
self.simulation.context.setPositions(np.random.random(positions.shape))
self.simulation.context.setPositions([mm.Vec3(0, 0, 0)] * len(positions))
# then reload the right positions from the checkpoint
file.close()
with open(file.name, 'rb') as f:
self.simulation.context.loadCheckpoint(f.read())
file.close()
os.unlink(file.name)
newPositions = self.simulation.context.getState(getPositions=True).getPositions(asNumpy=True)._value
np.testing.assert_array_equal(positions, newPositions)
newPositions = self.simulation.context.getState(getPositions=True).getPositions()
self.assertSequenceEqual(positions, newPositions)
if __name__ == '__main__':
unittest.main()
......@@ -3,12 +3,16 @@ from validateConstraints import *
from simtk.openmm.app import *
from simtk.openmm import *
from simtk.unit import *
from simtk.openmm.app.gromacstopfile import _defaultGromacsIncludeDir
import simtk.openmm.app.element as elem
GROMACS_INCLUDE = _defaultGromacsIncludeDir()
@unittest.skipIf(not os.path.exists(GROMACS_INCLUDE), 'GROMACS is not installed')
class TestGromacsTopFile(unittest.TestCase):
"""Test the GromacsTopFile.createSystem() method."""
def setUp(self):
"""Set up the tests by loading the input files."""
......@@ -20,15 +24,15 @@ class TestGromacsTopFile(unittest.TestCase):
def test_NonbondedMethod(self):
"""Test all five options for the nonbondedMethod parameter."""
methodMap = {NoCutoff:NonbondedForce.NoCutoff,
CutoffNonPeriodic:NonbondedForce.CutoffNonPeriodic,
CutoffPeriodic:NonbondedForce.CutoffPeriodic,
methodMap = {NoCutoff:NonbondedForce.NoCutoff,
CutoffNonPeriodic:NonbondedForce.CutoffNonPeriodic,
CutoffPeriodic:NonbondedForce.CutoffPeriodic,
Ewald:NonbondedForce.Ewald, PME: NonbondedForce.PME}
for method in methodMap:
system = self.top1.createSystem(nonbondedMethod=method)
forces = system.getForces()
self.assertTrue(any(isinstance(f, NonbondedForce) and
f.getNonbondedMethod()==methodMap[method]
self.assertTrue(any(isinstance(f, NonbondedForce) and
f.getNonbondedMethod()==methodMap[method]
for f in forces))
def test_ff99SBILDN(self):
......@@ -49,8 +53,8 @@ class TestGromacsTopFile(unittest.TestCase):
"""Test to make sure the nonbondedCutoff parameter is passed correctly."""
for method in [CutoffNonPeriodic, CutoffPeriodic, Ewald, PME]:
system = self.top1.createSystem(nonbondedMethod=method,
nonbondedCutoff=2*nanometer,
system = self.top1.createSystem(nonbondedMethod=method,
nonbondedCutoff=2*nanometer,
constraints=HBonds)
cutoff_distance = 0.0*nanometer
cutoff_check = 2.0*nanometer
......@@ -63,8 +67,8 @@ class TestGromacsTopFile(unittest.TestCase):
"""Test to make sure the ewaldErrorTolerance parameter is passed correctly."""
for method in [Ewald, PME]:
system = self.top1.createSystem(nonbondedMethod=method,
ewaldErrorTolerance=1e-6,
system = self.top1.createSystem(nonbondedMethod=method,
ewaldErrorTolerance=1e-6,
constraints=HBonds)
tolerance = 0
tolerance_check = 1e-6
......@@ -86,9 +90,9 @@ class TestGromacsTopFile(unittest.TestCase):
topology = self.top1.topology
for constraints_value in [None, HBonds, AllBonds, HAngles]:
for rigidWater_value in [True, False]:
system = self.top1.createSystem(constraints=constraints_value,
system = self.top1.createSystem(constraints=constraints_value,
rigidWater=rigidWater_value)
validateConstraints(self, topology, system,
validateConstraints(self, topology, system,
constraints_value, rigidWater_value)
def test_ImplicitSolvent(self):
......@@ -102,8 +106,8 @@ class TestGromacsTopFile(unittest.TestCase):
"""Test that solventDielectric and soluteDielectric are passed correctly.
"""
system = self.top2.createSystem(implicitSolvent=OBC2,
solventDielectric=50.0,
system = self.top2.createSystem(implicitSolvent=OBC2,
solventDielectric=50.0,
soluteDielectric = 0.9)
found_matching_solvent_dielectric=False
found_matching_solute_dielectric=False
......@@ -115,12 +119,12 @@ class TestGromacsTopFile(unittest.TestCase):
found_matching_solute_dielectric = True
if isinstance(force, NonbondedForce):
self.assertEqual(force.getReactionFieldDielectric(), 1.0)
self.assertTrue(found_matching_solvent_dielectric and
self.assertTrue(found_matching_solvent_dielectric and
found_matching_solute_dielectric)
def test_HydrogenMass(self):
"""Test that altering the mass of hydrogens works correctly."""
topology = self.top1.topology
hydrogenMass = 4*amu
system1 = self.top1.createSystem()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment