"platforms/vscode:/vscode.git/clone" did not exist on "5f374e1da59d7325e2aed9683fc672239bf8f34d"
Commit ecc7e011 authored by Robert McGibbon's avatar Robert McGibbon
Browse files

Modify Context.getState to accept a set of indices in addition to a bitmask

parent 8eaf3c9c
......@@ -45,9 +45,9 @@
enforcePeriodicBox = False,
groups = -1)
-> State
Get a State object recording the current state information stored in this context.
Parameters:
- getPositions (bool=False) whether to store particle positions in the State
- getVelocities (bool=False) whether to store particle velocities in the State
......@@ -55,26 +55,31 @@
- getEnergy (bool=False) whether to store potential and kinetic energy in the State
- getParameter (bool=False) whether to store context parameters in the State
- enforcePeriodicBox (bool=False) if false, the position of each particle will be whatever position is stored in the Context, regardless of periodic boundary conditions. If true, particle positions will be translated so the center of every molecule lies in the same periodic box.
- groups (int=-1) a set of bit flags for which force groups to include when computing forces and energies. Group i will be included if (groups&(1<<i)) != 0. The default value includes all groups.
- groups (set={0,1,2,...,31}) a set of indices for which force groups
to include when computing forces and energies. The default value
includes all groups. groups can also be passed as an unsigned integer
interpreted as a bitmask, in which case group i will be included if
(groups&(1<<i)) != 0.
"""
if getPositions: getP=1
else: getP=0
if getVelocities: getV=1
else: getV=0
if getForces: getF=1
else: getF=0
if getEnergy: getE=1
else: getE=0
if getParameters: getPa=1
else: getPa=0
if enforcePeriodicBox: enforcePeriodic=1
else: enforcePeriodic=0
getP, getV, getF, getE, getPa, enforcePeriodic = map(bool,
(getPositions, getVelocities, getForces, getEnergy, getParameters,
enforcePeriodicBox))
try:
# is the input integer-like?
groups_mask = int(groups)
except TypeError:
if isinstance(groups, set):
# nope, okay, then it should be an set
groups_mask = functools.reduce(operator.or_,
((1<<x) & 0xffffffff for x in groups))
else:
raise TypeError('%s is neither an int nor set' % groups)
(simTime, periodicBoxVectorsList, energy, coordList, velList,
forceList, paramMap) = \
self._getStateAsLists(getP, getV, getF, getE, getPa, enforcePeriodic, groups)
self._getStateAsLists(getP, getV, getF, getE, getPa, enforcePeriodic, groups_mask)
state = State(simTime=simTime,
energy=energy,
coordList=coordList,
......@@ -83,11 +88,11 @@
periodicBoxVectorsList=periodicBoxVectorsList,
paramMap=paramMap)
return state
def setState(self, state):
"""
setState(Context self, State state)
Copy information from a State object into this Context. This restores the Context to
approximately the same state it was in when the State was created. If the State does not include
a piece of information (e.g. positions or velocities), that aspect of the Context is
......@@ -108,7 +113,7 @@
for param in state._paramMap:
self.setParameter(param, state._paramMap[param])
%}
%feature("docstring") createCheckpoint "Create a checkpoint recording the current state of the Context.
This should be treated as an opaque block of binary data. See loadCheckpoint() for more details.
......@@ -196,9 +201,9 @@ Parameters:
enforcePeriodicBox = False,
groups = -1)
-> State
Get a State object recording the current state information about one copy of the system.
Parameters:
- copy (int) the index of the copy for which to retrieve state information
- getPositions (bool=False) whether to store particle positions in the State
......@@ -207,26 +212,30 @@ Parameters:
- getEnergy (bool=False) whether to store potential and kinetic energy in the State
- getParameter (bool=False) whether to store context parameters in the State
- enforcePeriodicBox (bool=False) if false, the position of each particle will be whatever position is stored in the Context, regardless of periodic boundary conditions. If true, particle positions will be translated so the center of every molecule lies in the same periodic box.
- groups (int=-1) a set of bit flags for which force groups to include when computing forces and energies. Group i will be included if (groups&(1<<i)) != 0. The default value includes all groups.
- groups (set={0,1,2,...,31}) a set of indices for which force groups
to include when computing forces and energies. The default value
includes all groups. groups can also be passed as an unsigned integer
interpreted as a bitmask, in which case group i will be included if
(groups&(1<<i)) != 0.
"""
if getPositions: getP=1
else: getP=0
if getVelocities: getV=1
else: getV=0
if getForces: getF=1
else: getF=0
if getEnergy: getE=1
else: getE=0
if getParameters: getPa=1
else: getPa=0
if enforcePeriodicBox: enforcePeriodic=1
else: enforcePeriodic=0
getP, getV, getF, getE, getPa, enforcePeriodic = map(bool,
(getPositions, getVelocities, getForces, getEnergy, getParameters,
enforcePeriodicBox))
try:
# is the input integer-like?
groups_mask = int(groups)
except TypeError:
if isinstance(groups, set):
groups_mask = functools.reduce(operator.or_,
((1<<x) & 0xffffffff for x in groups))
else:
raise TypeError('%s is neither an int nor set' % groups)
(simTime, periodicBoxVectorsList, energy, coordList, velList,
forceList, paramMap) = \
self._getStateAsLists(copy, getP, getV, getF, getE, getPa, enforcePeriodic, groups)
forceList, paramMap) = \
self._getStateAsLists(copy, getP, getV, getF, getE, getPa, enforcePeriodic, groups_mask)
state = State(simTime=simTime,
energy=energy,
coordList=coordList,
......@@ -299,7 +308,7 @@ Parameters:
ss << inputString;
return OpenMM::XmlSerializer::deserialize<OpenMM::System>(ss);
}
static std::string _serializeForce(const OpenMM::Force* object) {
std::stringstream ss;
OpenMM::XmlSerializer::serialize<OpenMM::Force>(object, "Force", ss);
......@@ -312,7 +321,7 @@ Parameters:
ss << inputString;
return OpenMM::XmlSerializer::deserialize<OpenMM::Force>(ss);
}
static std::string _serializeIntegrator(const OpenMM::Integrator* object) {
std::stringstream ss;
OpenMM::XmlSerializer::serialize<OpenMM::Integrator>(object, "Integrator", ss);
......@@ -327,8 +336,8 @@ Parameters:
}
static std::string _serializeStateAsLists(
const std::vector<Vec3>& pos,
const std::vector<Vec3>& vel,
const std::vector<Vec3>& pos,
const std::vector<Vec3>& vel,
const std::vector<Vec3>& forces,
double kineticEnergy,
double potentialEnergy,
......@@ -341,7 +350,7 @@ Parameters:
OpenMM::XmlSerializer::serialize<OpenMM::State>(&myState, "State", buffer);
return buffer.str();
}
static PyObject* _deserializeStringIntoLists(const std::string &stateAsString) {
std::stringstream ss;
ss << stateAsString;
......@@ -369,7 +378,7 @@ Parameters:
try:
velocities = pythonState.getVelocities().value_in_unit(unit.nanometers/unit.picoseconds)
types |= 2
except:
except:
pass
try:
forces = pythonState.getForces().value_in_unit(unit.kilojoules_per_mole/unit.nanometers)
......@@ -390,14 +399,14 @@ Parameters:
time = pythonState.getTime().value_in_unit(unit.picoseconds)
boxVectors = pythonState.getPeriodicBoxVectors().value_in_unit(unit.nanometers)
string = XmlSerializer._serializeStateAsLists(positions, velocities, forces, kineticEnergy, potentialEnergy, time, boxVectors, params, types)
return string
return string
@staticmethod
def _deserializeState(pythonString):
(simTime, periodicBoxVectorsList, energy, coordList, velList,
forceList, paramMap) = XmlSerializer._deserializeStringIntoLists(pythonString)
state = State(simTime=simTime,
energy=energy,
coordList=coordList,
......
......@@ -8,6 +8,8 @@ except ImportError:
import copy
import sys
import math
import functools
import operator
RMIN_PER_SIGMA=math.pow(2, 1/6.0)
RVDW_PER_SIGMA=math.pow(2, 1/6.0)/2.0
if sys.version_info[0] == 2:
......
import unittest
import itertools
import simtk.openmm as mm
class TestForceGroups(unittest.TestCase):
def setUp(self):
system = mm.System()
system.addParticle(1.0)
for i in range(32):
force = mm.CustomExternalForce(str(i))
force.addParticle(0, [])
force.setForceGroup(i)
system.addForce(force)
platform = mm.Platform.getPlatformByName('Reference')
context = mm.Context(system, mm.VerletIntegrator(0), platform)
context.setPositions([(0,0,0)])
self.context = context
def test1(self):
n = 31 # Should be 32, but github issue #1198
for (i,j) in itertools.combinations(range(n), 2):
groups = 1<<i | 1<<j
e_0 = self.context.getState(getEnergy=True, groups=groups).getPotentialEnergy()._value
e_1 = self.context.getState(getEnergy=True, groups={i,j}).getPotentialEnergy()._value
e_ref = i+j
self.assertEqual(e_0, e_ref)
self.assertEqual(e_1, e_ref)
def test2(self):
with self.assertRaises(TypeError):
# groups must be an int or set
self.context.getState(getEnergy=True, groups=(1, 2))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment