Commit 2c467990 authored by Peter Eastman's avatar Peter Eastman
Browse files

Python API for parameter derivatives

parent 7e114fb5
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
int getForces, int getForces,
int getEnergy, int getEnergy,
int getParameters, int getParameters,
int getParameterDerivatives,
int enforcePeriodic, int enforcePeriodic,
bitmask32t groups) { bitmask32t groups) {
State state; State state;
...@@ -38,6 +39,7 @@ ...@@ -38,6 +39,7 @@
if (getForces) types |= State::Forces; if (getForces) types |= State::Forces;
if (getEnergy) types |= State::Energy; if (getEnergy) types |= State::Energy;
if (getParameters) types |= State::Parameters; if (getParameters) types |= State::Parameters;
if (getParameterDerivatives) types |= State::ParameterDerivatives;
try { try {
state = self->getState(types, enforcePeriodic, groups); state = self->getState(types, enforcePeriodic, groups);
} }
...@@ -53,7 +55,7 @@ ...@@ -53,7 +55,7 @@
%pythoncode %{ %pythoncode %{
def getState(self, getPositions=False, getVelocities=False, def getState(self, getPositions=False, getVelocities=False,
getForces=False, getEnergy=False, getParameters=False, getForces=False, getEnergy=False, getParameters=False,
enforcePeriodicBox=False, groups=-1): getParameterDerivatives=False, enforcePeriodicBox=False, groups=-1):
"""Get a State object recording the current state information stored in this context. """Get a State object recording the current state information stored in this context.
Parameters Parameters
...@@ -66,8 +68,10 @@ ...@@ -66,8 +68,10 @@
whether to store the forces acting on particles in the State whether to store the forces acting on particles in the State
getEnergy : bool=False getEnergy : bool=False
whether to store potential and kinetic energy in the State whether to store potential and kinetic energy in the State
getParameter : bool=False getParameters : bool=False
whether to store context parameters in the State whether to store context parameters in the State
getParameterDerivatives : bool=False
whether to store parameter derivatives in the State
enforcePeriodicBox : bool=False enforcePeriodicBox : bool=False
if false, the position of each particle will be whatever position if false, the position of each particle will be whatever position
is stored in the Context, regardless of periodic boundary conditions. is stored in the Context, regardless of periodic boundary conditions.
...@@ -79,9 +83,9 @@ ...@@ -79,9 +83,9 @@
can also be passed as an unsigned integer interpreted as a bitmask, can also be passed as an unsigned integer interpreted as a bitmask,
in which case group i will be included if (groups&(1<<i)) != 0. in which case group i will be included if (groups&(1<<i)) != 0.
""" """
getP, getV, getF, getE, getPa, enforcePeriodic = map(bool, getP, getV, getF, getE, getPa, getPd, enforcePeriodic = map(bool,
(getPositions, getVelocities, getForces, getEnergy, getParameters, (getPositions, getVelocities, getForces, getEnergy, getParameters,
enforcePeriodicBox)) getParameterDerivatives, enforcePeriodicBox))
try: try:
# is the input integer-like? # is the input integer-like?
...@@ -95,8 +99,8 @@ ...@@ -95,8 +99,8 @@
raise TypeError('%s is neither an int nor set' % groups) raise TypeError('%s is neither an int nor set' % groups)
(simTime, periodicBoxVectorsList, energy, coordList, velList, (simTime, periodicBoxVectorsList, energy, coordList, velList,
forceList, paramMap) = \ forceList, paramMap, paramDerivMap) = \
self._getStateAsLists(getP, getV, getF, getE, getPa, enforcePeriodic, groups_mask) self._getStateAsLists(getP, getV, getF, getE, getPa, getPd, enforcePeriodic, groups_mask)
state = State(simTime=simTime, state = State(simTime=simTime,
energy=energy, energy=energy,
...@@ -104,7 +108,8 @@ ...@@ -104,7 +108,8 @@
velList=velList, velList=velList,
forceList=forceList, forceList=forceList,
periodicBoxVectorsList=periodicBoxVectorsList, periodicBoxVectorsList=periodicBoxVectorsList,
paramMap=paramMap) paramMap=paramMap,
paramDerivMap=paramDerivMap)
return state return state
def setState(self, state): def setState(self, state):
...@@ -176,6 +181,7 @@ Parameters: ...@@ -176,6 +181,7 @@ Parameters:
int getForces, int getForces,
int getEnergy, int getEnergy,
int getParameters, int getParameters,
int getParameterDerivatives,
int enforcePeriodic, int enforcePeriodic,
int groups) { int groups) {
State state; State state;
...@@ -186,6 +192,7 @@ Parameters: ...@@ -186,6 +192,7 @@ Parameters:
if (getForces) types |= State::Forces; if (getForces) types |= State::Forces;
if (getEnergy) types |= State::Energy; if (getEnergy) types |= State::Energy;
if (getParameters) types |= State::Parameters; if (getParameters) types |= State::Parameters;
if (getParameterDerivatives) types |= State::ParameterDerivatives;
try { try {
state = self->getState(copy, types, enforcePeriodic, groups); state = self->getState(copy, types, enforcePeriodic, groups);
} }
...@@ -206,6 +213,7 @@ Parameters: ...@@ -206,6 +213,7 @@ Parameters:
getForces=False, getForces=False,
getEnergy=False, getEnergy=False,
getParameters=False, getParameters=False,
getParameterDerivatives=False,
enforcePeriodicBox=False, enforcePeriodicBox=False,
groups=-1): groups=-1):
"""Get a State object recording the current state information about one copy of the system. """Get a State object recording the current state information about one copy of the system.
...@@ -222,8 +230,10 @@ Parameters: ...@@ -222,8 +230,10 @@ Parameters:
whether to store the forces acting on particles in the State whether to store the forces acting on particles in the State
getEnergy : bool=False getEnergy : bool=False
whether to store potential and kinetic energy in the State whether to store potential and kinetic energy in the State
getParameter : bool=False getParameters : bool=False
whether to store context parameters in the State whether to store context parameters in the State
getParameterDerivatives : bool=False
whether to store parameter derivatives in the State
enforcePeriodicBox : bool=False enforcePeriodicBox : bool=False
if false, the position of each particle will be whatever position if false, the position of each particle will be whatever position
is stored in the Context, regardless of periodic boundary conditions. is stored in the Context, regardless of periodic boundary conditions.
...@@ -235,9 +245,9 @@ Parameters: ...@@ -235,9 +245,9 @@ Parameters:
can also be passed as an unsigned integer interpreted as a bitmask, can also be passed as an unsigned integer interpreted as a bitmask,
in which case group i will be included if (groups&(1<<i)) != 0. in which case group i will be included if (groups&(1<<i)) != 0.
""" """
getP, getV, getF, getE, getPa, enforcePeriodic = map(bool, getP, getV, getF, getE, getPa, getPd, enforcePeriodic = map(bool,
(getPositions, getVelocities, getForces, getEnergy, getParameters, (getPositions, getVelocities, getForces, getEnergy, getParameters,
enforcePeriodicBox)) getParameterDerivatives, enforcePeriodicBox))
try: try:
# is the input integer-like? # is the input integer-like?
...@@ -250,8 +260,8 @@ Parameters: ...@@ -250,8 +260,8 @@ Parameters:
raise TypeError('%s is neither an int nor set' % groups) raise TypeError('%s is neither an int nor set' % groups)
(simTime, periodicBoxVectorsList, energy, coordList, velList, (simTime, periodicBoxVectorsList, energy, coordList, velList,
forceList, paramMap) = \ forceList, paramMap, paramDerivMap) = \
self._getStateAsLists(copy, getP, getV, getF, getE, getPa, enforcePeriodic, groups_mask) self._getStateAsLists(getP, getV, getF, getE, getPa, getPd, enforcePeriodic, groups_mask)
state = State(simTime=simTime, state = State(simTime=simTime,
energy=energy, energy=energy,
...@@ -259,7 +269,8 @@ Parameters: ...@@ -259,7 +269,8 @@ Parameters:
velList=velList, velList=velList,
forceList=forceList, forceList=forceList,
periodicBoxVectorsList=periodicBoxVectorsList, periodicBoxVectorsList=periodicBoxVectorsList,
paramMap=paramMap) paramMap=paramMap,
paramDerivMap=paramDerivMap)
return state return state
%} %}
} }
...@@ -361,8 +372,9 @@ Parameters: ...@@ -361,8 +372,9 @@ Parameters:
double time, double time,
const std::vector<Vec3>& boxVectors, const std::vector<Vec3>& boxVectors,
const std::map<string, double>& params, const std::map<string, double>& params,
const std::map<string, double>& paramDerivs,
int types) { int types) {
OpenMM::State myState = _convertListsToState(pos,vel,forces,kineticEnergy,potentialEnergy,time,boxVectors,params,types); OpenMM::State myState = _convertListsToState(pos,vel,forces,kineticEnergy,potentialEnergy,time,boxVectors,params,paramDerivs,types);
std::stringstream buffer; std::stringstream buffer;
OpenMM::XmlSerializer::serialize<OpenMM::State>(&myState, "State", buffer); OpenMM::XmlSerializer::serialize<OpenMM::State>(&myState, "State", buffer);
return buffer.str(); return buffer.str();
...@@ -386,6 +398,7 @@ Parameters: ...@@ -386,6 +398,7 @@ Parameters:
kineticEnergy = 0.0 kineticEnergy = 0.0
potentialEnergy = 0.0 potentialEnergy = 0.0
params = {} params = {}
paramDerivs = {}
types = 0 types = 0
try: try:
positions = pythonState.getPositions().value_in_unit(unit.nanometers) positions = pythonState.getPositions().value_in_unit(unit.nanometers)
...@@ -413,16 +426,21 @@ Parameters: ...@@ -413,16 +426,21 @@ Parameters:
types |= 16 types |= 16
except: except:
pass pass
try:
params = pythonState.getEnergyParameterDerivatives()
types |= 32
except:
pass
time = pythonState.getTime().value_in_unit(unit.picoseconds) time = pythonState.getTime().value_in_unit(unit.picoseconds)
boxVectors = pythonState.getPeriodicBoxVectors().value_in_unit(unit.nanometers) boxVectors = pythonState.getPeriodicBoxVectors().value_in_unit(unit.nanometers)
string = XmlSerializer._serializeStateAsLists(positions, velocities, forces, kineticEnergy, potentialEnergy, time, boxVectors, params, types) string = XmlSerializer._serializeStateAsLists(positions, velocities, forces, kineticEnergy, potentialEnergy, time, boxVectors, params, paramDerivs, types)
return string return string
@staticmethod @staticmethod
def _deserializeState(pythonString): def _deserializeState(pythonString):
(simTime, periodicBoxVectorsList, energy, coordList, velList, (simTime, periodicBoxVectorsList, energy, coordList, velList,
forceList, paramMap) = XmlSerializer._deserializeStringIntoLists(pythonString) forceList, paramMap, paramDerivMap) = XmlSerializer._deserializeStringIntoLists(pythonString)
state = State(simTime=simTime, state = State(simTime=simTime,
energy=energy, energy=energy,
...@@ -430,7 +448,8 @@ Parameters: ...@@ -430,7 +448,8 @@ Parameters:
velList=velList, velList=velList,
forceList=forceList, forceList=forceList,
periodicBoxVectorsList=periodicBoxVectorsList, periodicBoxVectorsList=periodicBoxVectorsList,
paramMap=paramMap) paramMap=paramMap,
paramDerivMap=paramDerivMap)
return state return state
@staticmethod @staticmethod
......
...@@ -28,6 +28,7 @@ State _convertListsToState( const std::vector<Vec3> &pos, ...@@ -28,6 +28,7 @@ State _convertListsToState( const std::vector<Vec3> &pos,
double time, double time,
const std::vector<Vec3> &boxVectors, const std::vector<Vec3> &boxVectors,
const std::map<std::string, double> &params, const std::map<std::string, double> &params,
const std::map<std::string, double> &paramDerivs,
int types ) { int types ) {
State::StateBuilder sb(time); State::StateBuilder sb(time);
if(types & State::Positions) if(types & State::Positions)
...@@ -40,6 +41,8 @@ State _convertListsToState( const std::vector<Vec3> &pos, ...@@ -40,6 +41,8 @@ State _convertListsToState( const std::vector<Vec3> &pos,
sb.setEnergy(kineticEnergy, potentialEnergy); sb.setEnergy(kineticEnergy, potentialEnergy);
if(types & State::Parameters) if(types & State::Parameters)
sb.setParameters(params); sb.setParameters(params);
if(types & State::ParameterDerivatives)
sb.setEnergyParameterDerivatives(paramDerivs);
sb.setPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]); sb.setPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
return sb.getState(); return sb.getState();
} }
...@@ -53,6 +56,7 @@ PyObject *_convertStateToLists(const State& state) { ...@@ -53,6 +56,7 @@ PyObject *_convertStateToLists(const State& state) {
PyObject *pForces; PyObject *pForces;
PyObject *pyTuple; PyObject *pyTuple;
PyObject *pParameters; PyObject *pParameters;
PyObject *pParameterDerivs;
simTime=state.getTime(); simTime=state.getTime();
OpenMM::Vec3 myVecA; OpenMM::Vec3 myVecA;
...@@ -112,11 +116,21 @@ PyObject *_convertStateToLists(const State& state) { ...@@ -112,11 +116,21 @@ PyObject *_convertStateToLists(const State& state) {
pParameters = Py_None; pParameters = Py_None;
Py_INCREF(Py_None); Py_INCREF(Py_None);
} }
try {
pParameterDerivs = PyDict_New();
const std::map<std::string, double>& params = state.getEnergyParameterDerivatives();
for (std::map<std::string, double>::const_iterator iter = params.begin(); iter != params.end(); ++iter)
PyDict_SetItemString(pParameterDerivs, iter->first.c_str(), Py_BuildValue("d", iter->second));
}
catch (std::exception& ex) {
pParameterDerivs = Py_None;
Py_INCREF(Py_None);
}
pyTuple=Py_BuildValue("(d,N,N,N,N,N,N)", pyTuple=Py_BuildValue("(d,N,N,N,N,N,N,N)",
simTime, pPeriodicBoxVectorsList, pEnergy, simTime, pPeriodicBoxVectorsList, pEnergy,
pPositions, pVelocities, pPositions, pVelocities,
pForces, pParameters); pForces, pParameters, pParameterDerivs);
return pyTuple; return pyTuple;
} }
......
...@@ -63,7 +63,8 @@ class State(_object): ...@@ -63,7 +63,8 @@ class State(_object):
velList=None, velList=None,
forceList=None, forceList=None,
periodicBoxVectorsList=None, periodicBoxVectorsList=None,
paramMap=None): paramMap=None,
paramDerivMap=None):
self._simTime=simTime self._simTime=simTime
self._periodicBoxVectorsList=periodicBoxVectorsList self._periodicBoxVectorsList=periodicBoxVectorsList
self._periodicBoxVectorsListNumpy=None self._periodicBoxVectorsListNumpy=None
...@@ -80,6 +81,7 @@ class State(_object): ...@@ -80,6 +81,7 @@ class State(_object):
self._forceList=forceList self._forceList=forceList
self._forceListNumpy=None self._forceListNumpy=None
self._paramMap=paramMap self._paramMap=paramMap
self._paramDerivMap=paramDerivMap
def __getstate__(self): def __getstate__(self):
serializationString = XmlSerializer.serialize(self) serializationString = XmlSerializer.serialize(self)
...@@ -221,6 +223,18 @@ class State(_object): ...@@ -221,6 +223,18 @@ class State(_object):
raise TypeError('Parameters were not requested in getState() call, so are not available.') raise TypeError('Parameters were not requested in getState() call, so are not available.')
return self._paramMap return self._paramMap
def getEnergyParameterDerivatives(self):
"""Get a map containing derivatives of the potential energy with respect to context parameters.
In most cases derivatives are only calculated if the corresponding Force objects have been
specifically told to compute them. Otherwise, the values in the map will be zero. Likewise,
if multiple Forces depend on the same parameter but only some have been told to compute
derivatives with respect to it, the returned value will include only the contributions from
the Forces that were told to compute it."""
if self._paramDerivMap is None:
raise TypeError('Parameter derivatives were not requested in getState() call, so are not available.')
return self._paramDerivMap
%} %}
%pythonappend OpenMM::Context::Context %{ %pythonappend OpenMM::Context::Context %{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment