Commit 2c467990 authored by Peter Eastman's avatar Peter Eastman
Browse files

Python API for parameter derivatives

parent 7e114fb5
......@@ -28,6 +28,7 @@
int getForces,
int getEnergy,
int getParameters,
int getParameterDerivatives,
int enforcePeriodic,
bitmask32t groups) {
State state;
......@@ -38,6 +39,7 @@
if (getForces) types |= State::Forces;
if (getEnergy) types |= State::Energy;
if (getParameters) types |= State::Parameters;
if (getParameterDerivatives) types |= State::ParameterDerivatives;
try {
state = self->getState(types, enforcePeriodic, groups);
}
......@@ -53,7 +55,7 @@
%pythoncode %{
def getState(self, getPositions=False, getVelocities=False,
getForces=False, getEnergy=False, getParameters=False,
enforcePeriodicBox=False, groups=-1):
getParameterDerivatives=False, enforcePeriodicBox=False, groups=-1):
"""Get a State object recording the current state information stored in this context.
Parameters
......@@ -66,8 +68,10 @@
whether to store the forces acting on particles in the State
getEnergy : bool=False
whether to store potential and kinetic energy in the State
getParameter : bool=False
getParameters : bool=False
whether to store context parameters in the State
getParameterDerivatives : bool=False
whether to store parameter derivatives in the State
enforcePeriodicBox : bool=False
if false, the position of each particle will be whatever position
is stored in the Context, regardless of periodic boundary conditions.
......@@ -79,9 +83,9 @@
can also be passed as an unsigned integer interpreted as a bitmask,
in which case group i will be included if (groups&(1<<i)) != 0.
"""
getP, getV, getF, getE, getPa, enforcePeriodic = map(bool,
getP, getV, getF, getE, getPa, getPd, enforcePeriodic = map(bool,
(getPositions, getVelocities, getForces, getEnergy, getParameters,
enforcePeriodicBox))
getParameterDerivatives, enforcePeriodicBox))
try:
# is the input integer-like?
......@@ -95,8 +99,8 @@
raise TypeError('%s is neither an int nor set' % groups)
(simTime, periodicBoxVectorsList, energy, coordList, velList,
forceList, paramMap) = \
self._getStateAsLists(getP, getV, getF, getE, getPa, enforcePeriodic, groups_mask)
forceList, paramMap, paramDerivMap) = \
self._getStateAsLists(getP, getV, getF, getE, getPa, getPd, enforcePeriodic, groups_mask)
state = State(simTime=simTime,
energy=energy,
......@@ -104,7 +108,8 @@
velList=velList,
forceList=forceList,
periodicBoxVectorsList=periodicBoxVectorsList,
paramMap=paramMap)
paramMap=paramMap,
paramDerivMap=paramDerivMap)
return state
def setState(self, state):
......@@ -176,6 +181,7 @@ Parameters:
int getForces,
int getEnergy,
int getParameters,
int getParameterDerivatives,
int enforcePeriodic,
int groups) {
State state;
......@@ -186,6 +192,7 @@ Parameters:
if (getForces) types |= State::Forces;
if (getEnergy) types |= State::Energy;
if (getParameters) types |= State::Parameters;
if (getParameterDerivatives) types |= State::ParameterDerivatives;
try {
state = self->getState(copy, types, enforcePeriodic, groups);
}
......@@ -206,6 +213,7 @@ Parameters:
getForces=False,
getEnergy=False,
getParameters=False,
getParameterDerivatives=False,
enforcePeriodicBox=False,
groups=-1):
"""Get a State object recording the current state information about one copy of the system.
......@@ -222,8 +230,10 @@ Parameters:
whether to store the forces acting on particles in the State
getEnergy : bool=False
whether to store potential and kinetic energy in the State
getParameter : bool=False
getParameters : bool=False
whether to store context parameters in the State
getParameterDerivatives : bool=False
whether to store parameter derivatives in the State
enforcePeriodicBox : bool=False
if false, the position of each particle will be whatever position
is stored in the Context, regardless of periodic boundary conditions.
......@@ -235,9 +245,9 @@ Parameters:
can also be passed as an unsigned integer interpreted as a bitmask,
in which case group i will be included if (groups&(1<<i)) != 0.
"""
getP, getV, getF, getE, getPa, enforcePeriodic = map(bool,
getP, getV, getF, getE, getPa, getPd, enforcePeriodic = map(bool,
(getPositions, getVelocities, getForces, getEnergy, getParameters,
enforcePeriodicBox))
getParameterDerivatives, enforcePeriodicBox))
try:
# is the input integer-like?
......@@ -250,8 +260,8 @@ Parameters:
raise TypeError('%s is neither an int nor set' % groups)
(simTime, periodicBoxVectorsList, energy, coordList, velList,
forceList, paramMap) = \
self._getStateAsLists(copy, getP, getV, getF, getE, getPa, enforcePeriodic, groups_mask)
forceList, paramMap, paramDerivMap) = \
self._getStateAsLists(getP, getV, getF, getE, getPa, getPd, enforcePeriodic, groups_mask)
state = State(simTime=simTime,
energy=energy,
......@@ -259,7 +269,8 @@ Parameters:
velList=velList,
forceList=forceList,
periodicBoxVectorsList=periodicBoxVectorsList,
paramMap=paramMap)
paramMap=paramMap,
paramDerivMap=paramDerivMap)
return state
%}
}
......@@ -361,8 +372,9 @@ Parameters:
double time,
const std::vector<Vec3>& boxVectors,
const std::map<string, double>& params,
const std::map<string, double>& paramDerivs,
int types) {
OpenMM::State myState = _convertListsToState(pos,vel,forces,kineticEnergy,potentialEnergy,time,boxVectors,params,types);
OpenMM::State myState = _convertListsToState(pos,vel,forces,kineticEnergy,potentialEnergy,time,boxVectors,params,paramDerivs,types);
std::stringstream buffer;
OpenMM::XmlSerializer::serialize<OpenMM::State>(&myState, "State", buffer);
return buffer.str();
......@@ -386,6 +398,7 @@ Parameters:
kineticEnergy = 0.0
potentialEnergy = 0.0
params = {}
paramDerivs = {}
types = 0
try:
positions = pythonState.getPositions().value_in_unit(unit.nanometers)
......@@ -413,16 +426,21 @@ Parameters:
types |= 16
except:
pass
try:
params = pythonState.getEnergyParameterDerivatives()
types |= 32
except:
pass
time = pythonState.getTime().value_in_unit(unit.picoseconds)
boxVectors = pythonState.getPeriodicBoxVectors().value_in_unit(unit.nanometers)
string = XmlSerializer._serializeStateAsLists(positions, velocities, forces, kineticEnergy, potentialEnergy, time, boxVectors, params, types)
string = XmlSerializer._serializeStateAsLists(positions, velocities, forces, kineticEnergy, potentialEnergy, time, boxVectors, params, paramDerivs, types)
return string
@staticmethod
def _deserializeState(pythonString):
(simTime, periodicBoxVectorsList, energy, coordList, velList,
forceList, paramMap) = XmlSerializer._deserializeStringIntoLists(pythonString)
forceList, paramMap, paramDerivMap) = XmlSerializer._deserializeStringIntoLists(pythonString)
state = State(simTime=simTime,
energy=energy,
......@@ -430,7 +448,8 @@ Parameters:
velList=velList,
forceList=forceList,
periodicBoxVectorsList=periodicBoxVectorsList,
paramMap=paramMap)
paramMap=paramMap,
paramDerivMap=paramDerivMap)
return state
@staticmethod
......
......@@ -28,6 +28,7 @@ State _convertListsToState( const std::vector<Vec3> &pos,
double time,
const std::vector<Vec3> &boxVectors,
const std::map<std::string, double> &params,
const std::map<std::string, double> &paramDerivs,
int types ) {
State::StateBuilder sb(time);
if(types & State::Positions)
......@@ -40,6 +41,8 @@ State _convertListsToState( const std::vector<Vec3> &pos,
sb.setEnergy(kineticEnergy, potentialEnergy);
if(types & State::Parameters)
sb.setParameters(params);
if(types & State::ParameterDerivatives)
sb.setEnergyParameterDerivatives(paramDerivs);
sb.setPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
return sb.getState();
}
......@@ -53,6 +56,7 @@ PyObject *_convertStateToLists(const State& state) {
PyObject *pForces;
PyObject *pyTuple;
PyObject *pParameters;
PyObject *pParameterDerivs;
simTime=state.getTime();
OpenMM::Vec3 myVecA;
......@@ -112,11 +116,21 @@ PyObject *_convertStateToLists(const State& state) {
pParameters = Py_None;
Py_INCREF(Py_None);
}
try {
pParameterDerivs = PyDict_New();
const std::map<std::string, double>& params = state.getEnergyParameterDerivatives();
for (std::map<std::string, double>::const_iterator iter = params.begin(); iter != params.end(); ++iter)
PyDict_SetItemString(pParameterDerivs, iter->first.c_str(), Py_BuildValue("d", iter->second));
}
catch (std::exception& ex) {
pParameterDerivs = Py_None;
Py_INCREF(Py_None);
}
pyTuple=Py_BuildValue("(d,N,N,N,N,N,N)",
pyTuple=Py_BuildValue("(d,N,N,N,N,N,N,N)",
simTime, pPeriodicBoxVectorsList, pEnergy,
pPositions, pVelocities,
pForces, pParameters);
pForces, pParameters, pParameterDerivs);
return pyTuple;
}
......
......@@ -63,7 +63,8 @@ class State(_object):
velList=None,
forceList=None,
periodicBoxVectorsList=None,
paramMap=None):
paramMap=None,
paramDerivMap=None):
self._simTime=simTime
self._periodicBoxVectorsList=periodicBoxVectorsList
self._periodicBoxVectorsListNumpy=None
......@@ -80,6 +81,7 @@ class State(_object):
self._forceList=forceList
self._forceListNumpy=None
self._paramMap=paramMap
self._paramDerivMap=paramDerivMap
def __getstate__(self):
serializationString = XmlSerializer.serialize(self)
......@@ -221,6 +223,18 @@ class State(_object):
raise TypeError('Parameters were not requested in getState() call, so are not available.')
return self._paramMap
def getEnergyParameterDerivatives(self):
"""Get a map containing derivatives of the potential energy with respect to context parameters.
In most cases derivatives are only calculated if the corresponding Force objects have been
specifically told to compute them. Otherwise, the values in the map will be zero. Likewise,
if multiple Forces depend on the same parameter but only some have been told to compute
derivatives with respect to it, the returned value will include only the contributions from
the Forces that were told to compute it."""
if self._paramDerivMap is None:
raise TypeError('Parameter derivatives were not requested in getState() call, so are not available.')
return self._paramDerivMap
%}
%pythonappend OpenMM::Context::Context %{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment