Commit 5d4f86a4 authored by Yutong Zhao's avatar Yutong Zhao
Browse files

Added state serialization and deserialization support for Python

parent 898ddbe7
...@@ -30,6 +30,7 @@ namespace std { ...@@ -30,6 +30,7 @@ namespace std {
%template(vectorpairii) vector< pair<int,int> >; %template(vectorpairii) vector< pair<int,int> >;
%template(vectorstring) vector<string>; %template(vectorstring) vector<string>;
%template(mapstringstring) map<string,string>; %template(mapstringstring) map<string,string>;
%template(mapstringdouble) map<string,double>;
}; };
%include "windows.i" %include "windows.i"
...@@ -40,6 +41,7 @@ namespace std { ...@@ -40,6 +41,7 @@ namespace std {
#include <sstream> #include <sstream>
#include <exception> #include <exception>
#include <fstream>
#include "OpenMM.h" #include "OpenMM.h"
#include "OpenMMAmoeba.h" #include "OpenMMAmoeba.h"
#include "openmm/RPMDIntegrator.h" #include "openmm/RPMDIntegrator.h"
......
...@@ -200,10 +200,8 @@ Parameters: ...@@ -200,10 +200,8 @@ Parameters:
paramMap=paramMap) paramMap=paramMap)
return state return state
} }
} }
%extend OpenMM::NonbondedForce { %extend OpenMM::NonbondedForce {
%pythoncode { %pythoncode {
def addParticle_usingRVdw(self, charge, rVDW, epsilon): def addParticle_usingRVdw(self, charge, rVDW, epsilon):
...@@ -229,7 +227,6 @@ Parameters: ...@@ -229,7 +227,6 @@ Parameters:
} }
} }
%extend OpenMM::System { %extend OpenMM::System {
%pythoncode { %pythoncode {
def __getstate__(self): def __getstate__(self):
...@@ -245,6 +242,90 @@ Parameters: ...@@ -245,6 +242,90 @@ Parameters:
} }
} }
%extend OpenMM::XmlSerializer {
static std::string _serializeStateAsLists(
const std::vector<Vec3>& pos,
const std::vector<Vec3>& vel,
const std::vector<Vec3>& forces,
double kineticEnergy,
double potentialEnergy,
double time,
const std::vector<Vec3>& boxVectors,
const std::map<string, double>& params,
int types) {
OpenMM::State myState = _convertListsToState(pos,vel,forces,kineticEnergy,potentialEnergy,time,boxVectors,params,types);
std::stringstream buffer;
OpenMM::XmlSerializer::serialize<OpenMM::State>(&myState, "State", buffer);
return buffer.str();
}
static PyObject* _deserializeStringIntoLists(const std::string &filename) {
std::fstream stateFile(filename.c_str(), std::ios::in);
OpenMM::State* deserializedState = OpenMM::XmlSerializer::deserialize<OpenMM::State>(stateFile);
PyObject* obj = _convertStateToLists(*deserializedState);
delete deserializedState;
return obj;
}
%pythoncode {
@staticmethod
def serializeState(pythonState):
positions = []
velocities = []
forces = []
kineticEnergy = 0.0
potentialEnergy = 0.0
types = 0
try:
positions = pythonState.getPositions().value_in_unit(unit.nanometers)
types |= 1
except:
pass
try:
velocities = pythonState.getVelocities().value_in_unit(unit.nanometers/unit.picoseconds)
types |= 2
except:
pass
try:
forces = pythonState.getForces().value_in_unit(unit.kilojoules_per_mole/unit.nanometers)
types |= 4
except:
pass
try:
kineticEnergy = pythonState.getKineticEnergy().value_in_unit(unit.kilojoules_per_mole)
potentialEnergy = pythonState.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
types |= 8
except:
pass
try:
params = pythonState.getParameters()
types |= 16
except:
pass
time = pythonState.getTime().value_in_unit(unit.picoseconds)
boxVectors = pythonState.getPeriodicBoxVectors().value_in_unit(unit.nanometers)
string = XmlSerializer._serializeStateAsLists(positions, velocities, forces, kineticEnergy, potentialEnergy, time, boxVectors, params, types)
return string
@staticmethod
def deserializeState(pythonString):
print pythonString
print type(pythonString)
(simTime, periodicBoxVectorsList, energy, coordList, velList,
forceList, paramMap) = XmlSerializer._deserializeStringIntoLists(pythonString)
state = State(simTime=simTime,
energy=energy,
coordList=coordList,
velList=velList,
forceList=forceList,
periodicBoxVectorsList=periodicBoxVectorsList,
paramMap=paramMap)
return state
}
}
%extend OpenMM::CustomIntegrator { %extend OpenMM::CustomIntegrator {
PyObject* getPerDofVariable(int index) const { PyObject* getPerDofVariable(int index) const {
std::vector<Vec3> values; std::vector<Vec3> values;
......
...@@ -20,6 +20,30 @@ PyObject *copyVVec3ToList(std::vector<Vec3> vVec3) { ...@@ -20,6 +20,30 @@ PyObject *copyVVec3ToList(std::vector<Vec3> vVec3) {
return pyList; return pyList;
} }
State _convertListsToState( const std::vector<Vec3> &pos,
const std::vector<Vec3> &vel,
const std::vector<Vec3> &forces,
double kineticEnergy,
double potentialEnergy,
double time,
const std::vector<Vec3> &boxVectors,
const std::map<std::string, double> &params,
int types ) {
State::StateBuilder sb(time);
if(types & State::Positions)
sb.setPositions(pos);
if(types & State::Velocities)
sb.setVelocities(vel);
if(types & State::Forces)
sb.setForces(forces);
if(types & State::Energy)
sb.setEnergy(kineticEnergy, potentialEnergy);
if(types & State::Parameters)
sb.setParameters(params);
sb.setPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
return sb.getState();
}
PyObject *_convertStateToLists(const State& state) { PyObject *_convertStateToLists(const State& state) {
double simTime; double simTime;
PyObject *pPeriodicBoxVectorsList; PyObject *pPeriodicBoxVectorsList;
......
...@@ -11,7 +11,6 @@ RVDW_PER_SIGMA=math.pow(2, 1/6.0)/2.0 ...@@ -11,7 +11,6 @@ RVDW_PER_SIGMA=math.pow(2, 1/6.0)/2.0
import simtk.unit as unit import simtk.unit as unit
class State(_object): class State(_object):
""" """
A State object records a snapshot of the A State object records a snapshot of the
...@@ -73,6 +72,14 @@ class State(_object): ...@@ -73,6 +72,14 @@ class State(_object):
self._forceListNumpy=None self._forceListNumpy=None
self._paramMap=paramMap self._paramMap=paramMap
def __getstate__(self):
serializationString = XmlSerializer.serializeState(self)
return serializationString
def __setstate__(self, serializationString):
state = XmlSerializer.deserializeState(serializationString)
self.this = state.this
def getTime(self): def getTime(self):
"""Get the time for which this State was created.""" """Get the time for which this State was created."""
return self._simTime * unit.picosecond return self._simTime * unit.picosecond
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment