Commit 6a50de8a authored by Robert McGibbon's avatar Robert McGibbon
Browse files

Fix force group 31

parent 596a4197
%inline %{
typedef int bitmask32t;
%}
%typemap(in) bitmask32t %{
$1 = 0;
#if PY_VERSION_HEX >= 0x03000000
if (PyLong_Check($input)) {
unsigned long u = PyLong_AsUnsignedLongMask($input);
#else
if (PyInt_Check($input)) {
unsigned long u = PyInt_AsUnsignedLongMask($input);
#endif
// 64-bit Windows has 32-bit longs, but other platforms have
// 64-bit longs
$1 = u & 0xffffffff;
} else {
PyErr_SetString(PyExc_ValueError, "in method $symname, argument $argnum could not be converted to type $type");
SWIG_fail;
}
%}
%extend OpenMM::Context { %extend OpenMM::Context {
PyObject *_getStateAsLists(int getPositions, PyObject *_getStateAsLists(int getPositions,
int getVelocities, int getVelocities,
...@@ -5,7 +29,7 @@ ...@@ -5,7 +29,7 @@
int getEnergy, int getEnergy,
int getParameters, int getParameters,
int enforcePeriodic, int enforcePeriodic,
int groups) { bitmask32t groups) {
State state; State state;
PyThreadState* _savePythonThreadState = PyEval_SaveThread(); PyThreadState* _savePythonThreadState = PyEval_SaveThread();
int types = 0; int types = 0;
...@@ -36,15 +60,9 @@ ...@@ -36,15 +60,9 @@
enforcePeriodicBox=False, enforcePeriodicBox=False,
groups=-1): groups=-1):
""" """
getState(self, getState(self, getPositions=False, getVelocities=False, getForces=False,
getPositions = False, getEnergy=False, getParameters=False, enforcePeriodicBox=False,
getVelocities = False, groups=-1) -> State
getForces = False,
getEnergy = False,
getParameters = False,
enforcePeriodicBox = False,
groups = -1)
-> State
Get a State object recording the current state information stored in this context. Get a State object recording the current state information stored in this context.
......
...@@ -20,7 +20,7 @@ class TestForceGroups(unittest.TestCase): ...@@ -20,7 +20,7 @@ class TestForceGroups(unittest.TestCase):
self.context = context self.context = context
def test1(self): def test1(self):
n = 31 # Should be 32, but github issue #1198 n = 32
for (i,j) in itertools.combinations(range(n), 2): for (i,j) in itertools.combinations(range(n), 2):
groups = 1<<i | 1<<j groups = 1<<i | 1<<j
e_0 = self.context.getState(getEnergy=True, groups=groups).getPotentialEnergy()._value e_0 = self.context.getState(getEnergy=True, groups=groups).getPotentialEnergy()._value
...@@ -34,6 +34,12 @@ class TestForceGroups(unittest.TestCase): ...@@ -34,6 +34,12 @@ class TestForceGroups(unittest.TestCase):
# groups must be an int or set # groups must be an int or set
self.context.getState(getEnergy=True, groups=(1, 2)) self.context.getState(getEnergy=True, groups=(1, 2))
def test3(self):
e_0 = self.context.getState(getEnergy=True, groups=-1).getPotentialEnergy()._value
e_ref = sum(range(32))
self.assertEqual(e_0, e_ref)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment