"...ssh:/git@developer.sourcefind.cn:2222/tsoc/openmm.git" did not exist on "62cc2344329abfa3a89c10e0e33a642ac51468d4"
Commit b84e22ba authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #1746 from peastman/state

Improved performance of Python State objects
parents 53f770f4 ac7953db
...@@ -10,6 +10,7 @@ install: ...@@ -10,6 +10,7 @@ install:
- "set PATH=C:\\Python35-x64;C:\\Python35-x64\\Scripts;%PATH%" - "set PATH=C:\\Python35-x64;C:\\Python35-x64\\Scripts;%PATH%"
- "set PATH=C:\\Program Files (x86)\\Git\\bin;%PATH%" - "set PATH=C:\\Program Files (x86)\\Git\\bin;%PATH%"
- pip install pytest - pip install pytest
- pip install numpy
# Use cclash for compiler caching (experimental) # Use cclash for compiler caching (experimental)
- ps: wget https://github.com/inorton/cclash/releases/download/0.3.14/cclash-0.3.14.zip -OutFile cclash-0.3.14.zip - ps: wget https://github.com/inorton/cclash/releases/download/0.3.14/cclash-0.3.14.zip -OutFile cclash-0.3.14.zip
......
...@@ -6,6 +6,7 @@ import re ...@@ -6,6 +6,7 @@ import re
import os import os
import sys import sys
import platform import platform
import numpy
from distutils.core import setup from distutils.core import setup
MAJOR_VERSION_NUM='@OPENMM_MAJOR_VERSION@' MAJOR_VERSION_NUM='@OPENMM_MAJOR_VERSION@'
...@@ -203,6 +204,7 @@ def buildKeywordDictionary(major_version_num=MAJOR_VERSION_NUM, ...@@ -203,6 +204,7 @@ def buildKeywordDictionary(major_version_num=MAJOR_VERSION_NUM,
library_dirs=[openmm_lib_path] library_dirs=[openmm_lib_path]
include_dirs=openmm_include_path.split(';') include_dirs=openmm_include_path.split(';')
include_dirs.append(numpy.get_include())
extensionArgs = {"name": "simtk.openmm._openmm", extensionArgs = {"name": "simtk.openmm._openmm",
"sources": ["src/swig_doxygen/OpenMMSwig.cxx"], "sources": ["src/swig_doxygen/OpenMMSwig.cxx"],
......
...@@ -1232,13 +1232,13 @@ ENABLE_PREPROCESSING = YES ...@@ -1232,13 +1232,13 @@ ENABLE_PREPROCESSING = YES
# compilation will be performed. Macro expansion can be done in a controlled # compilation will be performed. Macro expansion can be done in a controlled
# way by setting EXPAND_ONLY_PREDEF to YES. # way by setting EXPAND_ONLY_PREDEF to YES.
MACRO_EXPANSION = NO MACRO_EXPANSION = YES
# If the EXPAND_ONLY_PREDEF and MACRO_EXPANSION tags are both set to YES # If the EXPAND_ONLY_PREDEF and MACRO_EXPANSION tags are both set to YES
# then the macro expansion is limited to the macros specified with the # then the macro expansion is limited to the macros specified with the
# PREDEFINED and EXPAND_AS_DEFINED tags. # PREDEFINED and EXPAND_AS_DEFINED tags.
EXPAND_ONLY_PREDEF = NO EXPAND_ONLY_PREDEF = YES
# If the SEARCH_INCLUDES tag is set to YES (the default) the includes files # If the SEARCH_INCLUDES tag is set to YES (the default) the includes files
# in the INCLUDE_PATH (see below) will be search if a #include is found. # in the INCLUDE_PATH (see below) will be search if a #include is found.
...@@ -1266,7 +1266,7 @@ INCLUDE_FILE_PATTERNS = ...@@ -1266,7 +1266,7 @@ INCLUDE_FILE_PATTERNS =
# undefined via #undef or recursively expanded use the := operator # undefined via #undef or recursively expanded use the := operator
# instead of the = operator. # instead of the = operator.
PREDEFINED = PREDEFINED = OPENMM_EXPORT=
# If the MACRO_EXPANSION and EXPAND_ONLY_PREDEF tags are set to YES then # If the MACRO_EXPANSION and EXPAND_ONLY_PREDEF tags are set to YES then
# this tag can be used to specify a list of macro names that should be expanded. # this tag can be used to specify a list of macro names that should be expanded.
......
...@@ -13,7 +13,10 @@ DOC_STRINGS = {("Context", "setPositions") : ...@@ -13,7 +13,10 @@ DOC_STRINGS = {("Context", "setPositions") :
# Do not generate wrappers for the following methods. # Do not generate wrappers for the following methods.
# Indexed by (className, [methodName [, numParams]]) # Indexed by (className, [methodName [, numParams]])
SKIP_METHODS = [('State',), SKIP_METHODS = [('State', 'getPositions'),
('State', 'getVelocities'),
('State', 'getForces'),
('StateBuilder',),
('Vec3',), ('Vec3',),
('AngleInfo',), ('AngleInfo',),
('ApplyAndersenThermostatKernel',), ('ApplyAndersenThermostatKernel',),
...@@ -87,8 +90,6 @@ SKIP_METHODS = [('State',), ...@@ -87,8 +90,6 @@ SKIP_METHODS = [('State',),
('UpdateTimeKernel',), ('UpdateTimeKernel',),
('VdwInfo',), ('VdwInfo',),
('WcaDispersionInfo',), ('WcaDispersionInfo',),
('Context', 'getState'),
('Context', 'setState'),
('Context', 'createCheckpoint'), ('Context', 'createCheckpoint'),
('Context', 'loadCheckpoint'), ('Context', 'loadCheckpoint'),
('CudaPlatform',), ('CudaPlatform',),
...@@ -102,7 +103,6 @@ SKIP_METHODS = [('State',), ...@@ -102,7 +103,6 @@ SKIP_METHODS = [('State',),
('Platform', 'createKernel'), ('Platform', 'createKernel'),
('Platform', 'registerKernelFactory'), ('Platform', 'registerKernelFactory'),
('IntegrateRPMDStepKernel',), ('IntegrateRPMDStepKernel',),
('RPMDIntegrator', 'getState'),
('CalcDrudeForceKernel',), ('CalcDrudeForceKernel',),
('IntegrateDrudeLangevinStepKernel',), ('IntegrateDrudeLangevinStepKernel',),
('IntegrateDrudeSCFStepKernel',), ('IntegrateDrudeSCFStepKernel',),
...@@ -427,6 +427,13 @@ UNITS = { ...@@ -427,6 +427,13 @@ UNITS = {
: (None, (None, None, None, None, : (None, (None, None, None, None,
'unit.kilojoules_per_mole', 'unit.kilojoules_per_mole', 'unit.kilojoules_per_mole', 'unit.kilojoules_per_mole', 'unit.kilojoules_per_mole', 'unit.kilojoules_per_mole',
'unit.kilojoules_per_mole', 'unit.kilojoules_per_mole', 'unit.kilojoules_per_mole')), 'unit.kilojoules_per_mole', 'unit.kilojoules_per_mole', 'unit.kilojoules_per_mole')),
("State", "getTime") : ('unit.picosecond', ()),
("State", "getKineticEnergy") : ('unit.kilojoules_per_mole', ()),
("State", "getPotentialEnergy") : ('unit.kilojoules_per_mole', ()),
("State", "getPeriodicBoxVolume") : ('unit.nanometers**3', ()),
("State", "getPeriodicBoxVectors") : ('unit.nanometers', ()),
("State", "getParameters") : (None, ()),
("State", "getEnergyParameterDerivatives") : (None, ()),
("System", "getConstraintParameters") : (None, (None, None, 'unit.nanometer')), ("System", "getConstraintParameters") : (None, (None, None, 'unit.nanometer')),
("System", "getForce") : (None, ()), ("System", "getForce") : (None, ()),
("System", "getVirtualSite") : (None, ()), ("System", "getVirtualSite") : (None, ()),
......
%inline %{ %inline %{
typedef int bitmask32t; #include <cstring>
#include <numpy/arrayobject.h>
%} %}
%typemap(in) bitmask32t %{
$1 = 0;
#if PY_VERSION_HEX >= 0x03000000
if (PyLong_Check($input)) {
unsigned long u = PyLong_AsUnsignedLongMask($input);
#else
if (PyInt_Check($input)) {
unsigned long u = PyInt_AsUnsignedLongMask($input);
#endif
// 64-bit Windows has 32-bit longs, but other platforms have
// 64-bit longs
$1 = u & 0xffffffff;
} else {
PyErr_SetString(PyExc_ValueError, "in method $symname, argument $argnum could not be converted to type $type");
SWIG_fail;
}
%}
%extend OpenMM::Context { %extend OpenMM::Context {
PyObject *_getStateAsLists(int getPositions,
int getVelocities,
int getForces,
int getEnergy,
int getParameters,
int getParameterDerivatives,
int enforcePeriodic,
bitmask32t groups) {
State state;
PyThreadState* _savePythonThreadState = PyEval_SaveThread();
int types = 0;
if (getPositions) types |= State::Positions;
if (getVelocities) types |= State::Velocities;
if (getForces) types |= State::Forces;
if (getEnergy) types |= State::Energy;
if (getParameters) types |= State::Parameters;
if (getParameterDerivatives) types |= State::ParameterDerivatives;
try {
state = self->getState(types, enforcePeriodic, groups);
}
catch (...) {
PyEval_RestoreThread(_savePythonThreadState);
throw;
}
PyEval_RestoreThread(_savePythonThreadState);
return _convertStateToLists(state);
}
%pythoncode %{ %pythoncode %{
def getState(self, getPositions=False, getVelocities=False, def getState(self, getPositions=False, getVelocities=False,
...@@ -83,10 +36,6 @@ ...@@ -83,10 +36,6 @@
can also be passed as an unsigned integer interpreted as a bitmask, can also be passed as an unsigned integer interpreted as a bitmask,
in which case group i will be included if (groups&(1<<i)) != 0. in which case group i will be included if (groups&(1<<i)) != 0.
""" """
getP, getV, getF, getE, getPa, getPd, enforcePeriodic = map(bool,
(getPositions, getVelocities, getForces, getEnergy, getParameters,
getParameterDerivatives, enforcePeriodicBox))
try: try:
# is the input integer-like? # is the input integer-like?
groups_mask = int(groups) groups_mask = int(groups)
...@@ -97,44 +46,24 @@ ...@@ -97,44 +46,24 @@
((1<<x) & 0xffffffff for x in groups)) ((1<<x) & 0xffffffff for x in groups))
else: else:
raise TypeError('%s is neither an int nor set' % groups) raise TypeError('%s is neither an int nor set' % groups)
if groups_mask > 0x80000000:
(simTime, periodicBoxVectorsList, energy, coordList, velList, groups_mask -= 0x100000000
forceList, paramMap, paramDerivMap) = \ types = 0
self._getStateAsLists(getP, getV, getF, getE, getPa, getPd, enforcePeriodic, groups_mask) if getPositions:
types += State.Positions
state = State(simTime=simTime, if getVelocities:
energy=energy, types += State.Velocities
coordList=coordList, if getForces:
velList=velList, types += State.Forces
forceList=forceList, if getEnergy:
periodicBoxVectorsList=periodicBoxVectorsList, types += State.Energy
paramMap=paramMap, if getParameters:
paramDerivMap=paramDerivMap) types += State.Parameters
if getParameterDerivatives:
types += State.ParameterDerivatives
state = _openmm.Context_getState(self, types, enforcePeriodicBox, groups_mask)
return state return state
def setState(self, state):
"""
setState(Context self, State state)
Copy information from a State object into this Context. This restores the Context to
approximately the same state it was in when the State was created. If the State does not include
a piece of information (e.g. positions or velocities), that aspect of the Context is
left unchanged.
Even when all possible information is included in the State, the effect of calling this method
is still less complete than loadCheckpoint(). For example, it does not restore the internal
states of random number generators. On the other hand, it has the advantage of not being hardware
specific.
"""
self.setTime(state._simTime)
self.setPeriodicBoxVectors(state._periodicBoxVectorsList[0], state._periodicBoxVectorsList[1], state._periodicBoxVectorsList[2])
if state._coordList is not None:
self.setPositions(state._coordList)
if state._velList is not None:
self.setVelocities(state._velList)
if state._paramMap is not None:
for param in state._paramMap:
self.setParameter(param, state._paramMap[param])
%} %}
%feature("docstring") createCheckpoint "Create a checkpoint recording the current state of the Context. %feature("docstring") createCheckpoint "Create a checkpoint recording the current state of the Context.
...@@ -175,36 +104,6 @@ Parameters: ...@@ -175,36 +104,6 @@ Parameters:
} }
%extend OpenMM::RPMDIntegrator { %extend OpenMM::RPMDIntegrator {
PyObject *_getStateAsLists(int copy,
int getPositions,
int getVelocities,
int getForces,
int getEnergy,
int getParameters,
int getParameterDerivatives,
int enforcePeriodic,
int groups) {
State state;
PyThreadState* _savePythonThreadState = PyEval_SaveThread();
int types = 0;
if (getPositions) types |= State::Positions;
if (getVelocities) types |= State::Velocities;
if (getForces) types |= State::Forces;
if (getEnergy) types |= State::Energy;
if (getParameters) types |= State::Parameters;
if (getParameterDerivatives) types |= State::ParameterDerivatives;
try {
state = self->getState(copy, types, enforcePeriodic, groups);
}
catch (...) {
PyEval_RestoreThread(_savePythonThreadState);
throw;
}
PyEval_RestoreThread(_savePythonThreadState);
return _convertStateToLists(state);
}
%pythoncode %{ %pythoncode %{
def getState(self, def getState(self,
copy, copy,
...@@ -258,19 +157,22 @@ Parameters: ...@@ -258,19 +157,22 @@ Parameters:
((1<<x) & 0xffffffff for x in groups)) ((1<<x) & 0xffffffff for x in groups))
else: else:
raise TypeError('%s is neither an int nor set' % groups) raise TypeError('%s is neither an int nor set' % groups)
if groups_mask > 0x80000000:
(simTime, periodicBoxVectorsList, energy, coordList, velList, groups_mask -= 0x100000000
forceList, paramMap, paramDerivMap) = \ types = 0
self._getStateAsLists(getP, getV, getF, getE, getPa, getPd, enforcePeriodic, groups_mask) if getPositions:
types += State.Positions
state = State(simTime=simTime, if getVelocities:
energy=energy, types += State.Velocities
coordList=coordList, if getForces:
velList=velList, types += State.Forces
forceList=forceList, if getEnergy:
periodicBoxVectorsList=periodicBoxVectorsList, types += State.Energy
paramMap=paramMap, if getParameters:
paramDerivMap=paramDerivMap) types += State.Parameters
if getParameterDerivatives:
types += State.ParameterDerivatives
state = _openmm.RPMDIntegrator_getState(self, copy, types, enforcePeriodicBox, groups_mask)
return state return state
%} %}
} }
...@@ -376,95 +278,20 @@ Parameters: ...@@ -376,95 +278,20 @@ Parameters:
return OpenMM::XmlSerializer::deserialize<OpenMM::TabulatedFunction>(ss); return OpenMM::XmlSerializer::deserialize<OpenMM::TabulatedFunction>(ss);
} }
static std::string _serializeStateAsLists( static std::string _serializeState(const OpenMM::State* object) {
const std::vector<Vec3>& pos, std::stringstream ss;
const std::vector<Vec3>& vel, OpenMM::XmlSerializer::serialize<OpenMM::State>(object, "State", ss);
const std::vector<Vec3>& forces, return ss.str();
double kineticEnergy,
double potentialEnergy,
double time,
const std::vector<Vec3>& boxVectors,
const std::map<string, double>& params,
const std::map<string, double>& paramDerivs,
int types) {
OpenMM::State myState = _convertListsToState(pos,vel,forces,kineticEnergy,potentialEnergy,time,boxVectors,params,paramDerivs,types);
std::stringstream buffer;
OpenMM::XmlSerializer::serialize<OpenMM::State>(&myState, "State", buffer);
return buffer.str();
} }
static PyObject* _deserializeStringIntoLists(const std::string &stateAsString) { %newobject _deserializeState;
std::stringstream ss; static OpenMM::State* _deserializeState(const char* inputString) {
ss << stateAsString; std::stringstream ss;
OpenMM::State* deserializedState = OpenMM::XmlSerializer::deserialize<OpenMM::State>(ss); ss << inputString;
PyObject* obj = _convertStateToLists(*deserializedState); return OpenMM::XmlSerializer::deserialize<OpenMM::State>(ss);
delete deserializedState;
return obj;
} }
%pythoncode %{ %pythoncode %{
@staticmethod
def _serializeState(pythonState):
positions = []
velocities = []
forces = []
kineticEnergy = 0.0
potentialEnergy = 0.0
params = {}
paramDerivs = {}
types = 0
try:
positions = pythonState.getPositions().value_in_unit(unit.nanometers)
types |= 1
except:
pass
try:
velocities = pythonState.getVelocities().value_in_unit(unit.nanometers/unit.picoseconds)
types |= 2
except:
pass
try:
forces = pythonState.getForces().value_in_unit(unit.kilojoules_per_mole/unit.nanometers)
types |= 4
except:
pass
try:
kineticEnergy = pythonState.getKineticEnergy().value_in_unit(unit.kilojoules_per_mole)
potentialEnergy = pythonState.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
types |= 8
except:
pass
try:
params = pythonState.getParameters()
types |= 16
except:
pass
try:
params = pythonState.getEnergyParameterDerivatives()
types |= 32
except:
pass
time = pythonState.getTime().value_in_unit(unit.picoseconds)
boxVectors = pythonState.getPeriodicBoxVectors().value_in_unit(unit.nanometers)
string = XmlSerializer._serializeStateAsLists(positions, velocities, forces, kineticEnergy, potentialEnergy, time, boxVectors, params, paramDerivs, types)
return string
@staticmethod
def _deserializeState(pythonString):
(simTime, periodicBoxVectorsList, energy, coordList, velList,
forceList, paramMap, paramDerivMap) = XmlSerializer._deserializeStringIntoLists(pythonString)
state = State(simTime=simTime,
energy=energy,
coordList=coordList,
velList=velList,
forceList=forceList,
periodicBoxVectorsList=periodicBoxVectorsList,
paramMap=paramMap,
paramDerivMap=paramDerivMap)
return state
@staticmethod @staticmethod
def serialize(object): def serialize(object):
"""Serialize an object as XML.""" """Serialize an object as XML."""
...@@ -566,3 +393,118 @@ Parameters: ...@@ -566,3 +393,118 @@ Parameters:
return OpenMM::XmlSerializer::clone<OpenMM::TabulatedFunction>(*self); return OpenMM::XmlSerializer::clone<OpenMM::TabulatedFunction>(*self);
} }
} }
%extend OpenMM::State {
%pythoncode %{
def __getstate__(self):
serializationString = XmlSerializer.serialize(self)
return serializationString
def __setstate__(self, serializationString):
system = XmlSerializer.deserialize(serializationString)
self.this = system.this
def __deepcopy__(self, memo):
return self.__copy__()
def getPeriodicBoxVectors(self, asNumpy=False):
"""Get the vectors defining the axes of the periodic box."""
vectors = _openmm.State_getPeriodicBoxVectors(self)
if asNumpy:
vectors = numpy.array(vectors)
return vectors*unit.nanometers
def getPositions(self, asNumpy=False):
"""Get the position of each particle with units.
Raises an exception if positions where not requested in
the context.getState() call.
Returns a list of Vec3s, unless asNumpy is True, in
which case a Numpy array of arrays will be returned.
"""
if asNumpy:
if '_positionsNumpy' not in dir(self):
self._positionsNumpy = numpy.empty([self._getNumParticles(), 3], numpy.float64)
self._getVectorAsNumpy(State.Positions, self._positionsNumpy)
self._positionsNumpy = self._positionsNumpy*unit.nanometers
return self._positionsNumpy
if '_positions' not in dir(self):
self._positions = self._getVectorAsVec3(State.Positions)*unit.nanometers
return self._positions
def getVelocities(self, asNumpy=False):
"""Get the velocity of each particle with units.
Raises an exception if velocities where not requested in
the context.getState() call.
Returns a list of Vec3s if asNumpy is False, or a Numpy
array if asNumpy is True.
"""
if asNumpy:
if '_velocitiesNumpy' not in dir(self):
self._velocitiesNumpy = numpy.empty([self._getNumParticles(), 3], numpy.float64)
self._getVectorAsNumpy(State.Velocities, self._velocitiesNumpy)
self._velocitiesNumpy = self._velocitiesNumpy*unit.nanometers/unit.picosecond
return self._velocitiesNumpy
if '_velocities' not in dir(self):
self._velocities = self._getVectorAsVec3(State.Velocities)*unit.nanometers/unit.picosecond
return self._velocities
def getForces(self, asNumpy=False):
"""Get the force acting on each particle with units.
Raises an exception if forces where not requested in
the context.getState() call.
Returns a list of Vec3s if asNumpy is False, or a Numpy
array if asNumpy is True.
"""
if asNumpy:
if '_forcesNumpy' not in dir(self):
self._forcesNumpy = numpy.empty([self._getNumParticles(), 3], numpy.float64)
self._getVectorAsNumpy(State.Forces, self._forcesNumpy)
self._forcesNumpy = self._forcesNumpy*unit.kilojoules_per_mole/unit.nanometer
return self._forcesNumpy
if '_forces' not in dir(self):
self._forces = self._getVectorAsVec3(State.Forces)*unit.kilojoules_per_mole/unit.nanometer
return self._forces
%}
int _getNumParticles() {
if ((self->getDataTypes() & State::Positions) != 0)
return self->getPositions().size();
if ((self->getDataTypes() & State::Velocities) != 0)
return self->getVelocities().size();
if ((self->getDataTypes() & State::Forces) != 0)
return self->getForces().size();
return 0;
}
PyObject* _getVectorAsVec3(State::DataType type) {
if (type == State::Positions)
return copyVVec3ToList(self->getPositions());
if (type == State::Velocities)
return copyVVec3ToList(self->getVelocities());
if (type == State::Forces)
return copyVVec3ToList(self->getForces());
PyErr_SetString(PyExc_ValueError, "Illegal type specified in _getVectorAsVec3");
return NULL;
}
void _getVectorAsNumpy(State::DataType type, PyObject* output) {
const std::vector<Vec3>* array;
if (type == State::Positions)
array = &self->getPositions();
else if (type == State::Velocities)
array = &self->getVelocities();
else if (type == State::Forces)
array = &self->getForces();
else {
PyErr_SetString(PyExc_ValueError, "Illegal type specified in _getVectorAsNumpy");
return;
}
void* data = PyArray_DATA((PyArrayObject*) output);
memcpy(data, &array[0][0], 3*sizeof(double)*array->size());
}
%newobject __copy__;
OpenMM::State* __copy__() {
return OpenMM::XmlSerializer::clone<OpenMM::State>(*self);
}
}
...@@ -20,121 +20,6 @@ PyObject *copyVVec3ToList(std::vector<Vec3> vVec3) { ...@@ -20,121 +20,6 @@ PyObject *copyVVec3ToList(std::vector<Vec3> vVec3) {
return pyList; return pyList;
} }
State _convertListsToState( const std::vector<Vec3> &pos,
const std::vector<Vec3> &vel,
const std::vector<Vec3> &forces,
double kineticEnergy,
double potentialEnergy,
double time,
const std::vector<Vec3> &boxVectors,
const std::map<std::string, double> &params,
const std::map<std::string, double> &paramDerivs,
int types ) {
State::StateBuilder sb(time);
if(types & State::Positions)
sb.setPositions(pos);
if(types & State::Velocities)
sb.setVelocities(vel);
if(types & State::Forces)
sb.setForces(forces);
if(types & State::Energy)
sb.setEnergy(kineticEnergy, potentialEnergy);
if(types & State::Parameters)
sb.setParameters(params);
if(types & State::ParameterDerivatives)
sb.setEnergyParameterDerivatives(paramDerivs);
sb.setPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
return sb.getState();
}
PyObject *_convertStateToLists(const State& state) {
double simTime;
PyObject *pPeriodicBoxVectorsList;
PyObject *pEnergy;
PyObject *pPositions;
PyObject *pVelocities;
PyObject *pForces;
PyObject *pyTuple;
PyObject *pParameters;
PyObject *pParameterDerivs;
simTime=state.getTime();
OpenMM::Vec3 myVecA;
OpenMM::Vec3 myVecB;
OpenMM::Vec3 myVecC;
state.getPeriodicBoxVectors(myVecA, myVecB, myVecC);
PyObject* mm = PyImport_AddModule("simtk.openmm");
PyObject* vec3 = PyObject_GetAttrString(mm, "Vec3");
PyObject* args1 = Py_BuildValue("(d,d,d)", myVecA[0], myVecA[1], myVecA[2]);
PyObject* args2 = Py_BuildValue("(d,d,d)", myVecB[0], myVecB[1], myVecB[2]);
PyObject* args3 = Py_BuildValue("(d,d,d)", myVecC[0], myVecC[1], myVecC[2]);
PyObject* pyVec1 = PyObject_CallObject(vec3, args1);
PyObject* pyVec2 = PyObject_CallObject(vec3, args2);
PyObject* pyVec3 = PyObject_CallObject(vec3, args3);
Py_DECREF(args1);
Py_DECREF(args2);
Py_DECREF(args3);
pPeriodicBoxVectorsList = Py_BuildValue("N,N,N", pyVec1, pyVec2, pyVec3);
try {
pPositions = copyVVec3ToList(state.getPositions());
}
catch (std::exception& ex) {
pPositions = Py_None;
Py_INCREF(Py_None);
}
try {
pVelocities = copyVVec3ToList(state.getVelocities());
}
catch (std::exception& ex) {
pVelocities = Py_None;
Py_INCREF(Py_None);
}
try {
pForces = copyVVec3ToList(state.getForces());
}
catch (std::exception& ex) {
pForces = Py_None;
Py_INCREF(Py_None);
}
try {
pEnergy = Py_BuildValue("(d,d)",
state.getKineticEnergy(),
state.getPotentialEnergy());
}
catch (std::exception& ex) {
pEnergy = Py_None;
Py_INCREF(Py_None);
}
try {
pParameters = PyDict_New();
const std::map<std::string, double>& params = state.getParameters();
for (std::map<std::string, double>::const_iterator iter = params.begin(); iter != params.end(); ++iter)
PyDict_SetItemString(pParameters, iter->first.c_str(), Py_BuildValue("d", iter->second));
}
catch (std::exception& ex) {
pParameters = Py_None;
Py_INCREF(Py_None);
}
try {
pParameterDerivs = PyDict_New();
const std::map<std::string, double>& params = state.getEnergyParameterDerivatives();
for (std::map<std::string, double>::const_iterator iter = params.begin(); iter != params.end(); ++iter)
PyDict_SetItemString(pParameterDerivs, iter->first.c_str(), Py_BuildValue("d", iter->second));
}
catch (std::exception& ex) {
pParameterDerivs = Py_None;
Py_INCREF(Py_None);
}
pyTuple=Py_BuildValue("(d,N,N,N,N,N,N,N)",
simTime, pPeriodicBoxVectorsList, pEnergy,
pPositions, pVelocities,
pForces, pParameters, pParameterDerivs);
return pyTuple;
}
} // namespace OpenMM } // namespace OpenMM
%} %}
......
...@@ -20,220 +20,6 @@ else: ...@@ -20,220 +20,6 @@ else:
import simtk.unit as unit import simtk.unit as unit
from simtk.openmm.vec3 import Vec3 from simtk.openmm.vec3 import Vec3
class State(_object):
"""
A State object records a snapshot of the
current state of a simulation at a point
in time. You create it by calling
getState() on a Context.
When a State is created, you specify what
information should be stored in it. This
saves time and memory by only copying in
the information that you actually want.
This is especially important for forces
and energies, since they may need to be
calculated. If you query a State object
for a piece of information which is not
available (because it was not requested
when the State was created), it will
return None.
In general return values are Python Units
(https://simtk.org/home/python_units).
Among other things Python Units provides a
container class, Quantity, which holds a
value and a representation of the value's
unit. Values can be integers, floats,
lists, numarrays, etc. Quantity objects
can be used in arithmetic operation just
like number, except they also keep track
of units. To extract the value from a
quantity, us the value_in_unit() method.
For example, to extract the value from a
length quantity, in units of nanometers,
do the following:
myLengthQuantity.value_in_unit(unit.nanometer)
"""
def __init__(self,
simTime=None,
energy=None,
coordList=None,
velList=None,
forceList=None,
periodicBoxVectorsList=None,
paramMap=None,
paramDerivMap=None):
self._simTime=simTime
self._periodicBoxVectorsList=periodicBoxVectorsList
self._periodicBoxVectorsListNumpy=None
if energy:
self._eK0=energy[0]
self._eP0=energy[1]
else:
self._eK0=None
self._eP0=None
self._coordList=coordList
self._coordListNumpy=None
self._velList=velList
self._velListNumpy=None
self._forceList=forceList
self._forceListNumpy=None
self._paramMap=paramMap
self._paramDerivMap=paramDerivMap
def __getstate__(self):
serializationString = XmlSerializer.serialize(self)
return serializationString
def __setstate__(self, serializationString):
dState = XmlSerializer.deserialize(serializationString)
# Safe provided no __slots__ or other weird things are used
self.__dict__.update(dState.__dict__)
def getTime(self):
"""Get the time for which this State was created."""
return self._simTime * unit.picosecond
def getPeriodicBoxVectors(self, asNumpy=False):
"""Get the three periodic box vectors if this state is from a
simulation using PBC ."""
if self._periodicBoxVectorsList is None:
raise TypeError('periodic box vectors were not available.')
if asNumpy:
if self._periodicBoxVectorsListNumpy is None:
self._periodicBoxVectorsListNumpy = \
numpy.array(self._periodicBoxVectorsList)
returnValue=self._periodicBoxVectorsListNumpy
else:
returnValue=self._periodicBoxVectorsList
returnValue = unit.Quantity(returnValue, unit.nanometers)
return returnValue
def getPeriodicBoxVolume(self):
"""Get the volume of the periodic box."""
a = self._periodicBoxVectorsList[0]
b = self._periodicBoxVectorsList[1]
c = self._periodicBoxVectorsList[2]
bcrossc = Vec3(b[1]*c[2]-b[2]*c[1], b[2]*c[0]-b[0]*c[2], b[0]*c[1]-b[1]*c[0])
return unit.Quantity(unit.dot(a, bcrossc), unit.nanometers*unit.nanometers*unit.nanometers)
def getPositions(self, asNumpy=False):
"""Get the position of each particle with units.
Raises an exception if postions where not requested in
the context.getState() call.
Returns a list of tuples, unless asNumpy is True, in
which case a Numpy array of arrays will be returned.
To remove the units, divide return value by unit.angstrom
or unit.nanometer. See the following for details:
https://simtk.org/home/python_units
"""
if self._coordList is None:
raise TypeError('Positions were not requested in getState() call, so are not available.')
if asNumpy:
if self._coordListNumpy is None:
self._coordListNumpy=numpy.array(self._coordList)
returnValue=self._coordListNumpy
else:
returnValue=self._coordList
returnValue = unit.Quantity(returnValue, unit.nanometers)
return returnValue
def getVelocities(self, asNumpy=False):
"""Get the velocity of each particle with units.
Raises an exception if velocities where not requested in
the context.getState() call.
Returns a list of tuples, unless asNumpy is True, in
which case a Numpy array of arrays will be returned.
To remove the units, you can divide the return value by
unit.angstrom/unit.picosecond or unit.meter/unit.second,
etc. See the following for details:
https://simtk.org/home/python_units
"""
if self._velList is None:
raise TypeError('Velocities were not requested in getState() call, so are not available.')
if asNumpy:
if self._velListNumpy is None:
self._velListNumpy=numpy.array(self._velList)
returnValue=self._velListNumpy
else:
returnValue=self._velList
returnValue = unit.Quantity(returnValue, unit.nanometers/unit.picosecond)
return returnValue
def getForces(self, asNumpy=False):
"""Get the force acting on each particle with units.
Raises an exception if forces where not requested in
the context.getState() call.
Returns a list of tuples, unless asNumpy is True, in
which case a Numpy array of arrays will be returned.
To remove the units, you can divide the return value by
unit.kilojoule_per_mole/unit.angstrom or
unit.calorie_per_mole/unit.nanometer, etc.
See the following for details:
https://simtk.org/home/python_units
"""
if self._forceList is None:
raise TypeError('Forces were not requested in getState() call, so are not available.')
if asNumpy:
if self._forceListNumpy is None:
self._forceListNumpy=numpy.array(self._forceList)
returnValue=self._forceListNumpy
else:
returnValue=self._forceList
returnValue = unit.Quantity(returnValue,
unit.kilojoule_per_mole/unit.nanometer)
return returnValue
def getKineticEnergy(self):
"""Get the total kinetic energy of the system with units.
To remove the units, you can divide the return value by
unit.kilojoule_per_mole or unit.calorie_per_mole, etc.
See the following for details:
https://simtk.org/home/python_units
"""
if self._eK0 is None:
raise TypeError('Energy was not requested in getState() call, so it is not available.')
return self._eK0 * unit.kilojoule_per_mole
def getPotentialEnergy(self):
"""Get the total potential energy of the system with units.
To remove the units, you can divide the return value by
unit.kilojoule_per_mole or unit.kilocalorie_per_mole, etc.
See the following for details:
https://simtk.org/home/python_units
"""
if self._eP0 is None:
raise TypeError('Energy was not requested in getState() call, so it is not available.')
return self._eP0 * unit.kilojoule_per_mole
def getParameters(self):
"""Get a map containing the values of all parameters.
"""
if self._paramMap is None:
raise TypeError('Parameters were not requested in getState() call, so are not available.')
return self._paramMap
def getEnergyParameterDerivatives(self):
"""Get a map containing derivatives of the potential energy with respect to context parameters.
In most cases derivatives are only calculated if the corresponding Force objects have been
specifically told to compute them. Otherwise, the values in the map will be zero. Likewise,
if multiple Forces depend on the same parameter but only some have been told to compute
derivatives with respect to it, the returned value will include only the contributions from
the Forces that were told to compute it."""
if self._paramDerivMap is None:
raise TypeError('Parameter derivatives were not requested in getState() call, so are not available.')
return self._paramDerivMap
%} %}
......
...@@ -59,6 +59,12 @@ class TestNumpyCompatibility(unittest.TestCase): ...@@ -59,6 +59,12 @@ class TestNumpyCompatibility(unittest.TestCase):
np.testing.assert_array_almost_equal(input.value_in_unit(unit.angstroms / unit.femtoseconds), np.testing.assert_array_almost_equal(input.value_in_unit(unit.angstroms / unit.femtoseconds),
output.value_in_unit(unit.angstroms / unit.femtoseconds)) output.value_in_unit(unit.angstroms / unit.femtoseconds))
def test_periodicBoxVectors(self):
output = self.simulation.context.getState(getVelocities=True).getPeriodicBoxVectors(asNumpy=True)
systemBox = self.simulation.system.getDefaultPeriodicBoxVectors()
for i in range(3):
np.testing.assert_array_almost_equal(systemBox[i].value_in_unit(unit.nanometers), output[i].value_in_unit(unit.nanometers))
def test_tabulatedFunction(self): def test_tabulatedFunction(self):
f = mm.CustomNonbondedForce('g(r)') f = mm.CustomNonbondedForce('g(r)')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment