Commit 8fc1d49f authored by Yutong Zhao's avatar Yutong Zhao
Browse files

Added Serialization support for states. Also added a unit-test.

parent 277ab730
...@@ -104,6 +104,7 @@ public: ...@@ -104,6 +104,7 @@ public:
const std::map<std::string, double>& getParameters() const; const std::map<std::string, double>& getParameters() const;
private: private:
friend class Context; friend class Context;
friend class StateProxy;
State(double time, int numParticles, int types); State(double time, int numParticles, int types);
std::vector<Vec3>& updPositions(); std::vector<Vec3>& updPositions();
std::vector<Vec3>& updVelocities(); std::vector<Vec3>& updVelocities();
......
#ifndef STATE_PROXY_H_
#define STATE_PROXY_H_
#include "openmm/serialization/XmlSerializer.h"
namespace OpenMM { // needs to be for friend class to work
class StateProxy : public SerializationProxy {
public:
StateProxy();
void serialize(const void* object, SerializationNode& node) const;
void* deserialize(const SerializationNode& node) const;
};
}
#endif
\ No newline at end of file
...@@ -70,6 +70,7 @@ ...@@ -70,6 +70,7 @@
#include "openmm/serialization/PeriodicTorsionForceProxy.h" #include "openmm/serialization/PeriodicTorsionForceProxy.h"
#include "openmm/serialization/RBTorsionForceProxy.h" #include "openmm/serialization/RBTorsionForceProxy.h"
#include "openmm/serialization/SystemProxy.h" #include "openmm/serialization/SystemProxy.h"
#include "openmm/serialization/StateProxy.h"
#if defined(WIN32) #if defined(WIN32)
#include <windows.h> #include <windows.h>
...@@ -106,4 +107,5 @@ extern "C" void registerSerializationProxies() { ...@@ -106,4 +107,5 @@ extern "C" void registerSerializationProxies() {
SerializationProxy::registerProxy(typeid(PeriodicTorsionForce), new PeriodicTorsionForceProxy()); SerializationProxy::registerProxy(typeid(PeriodicTorsionForce), new PeriodicTorsionForceProxy());
SerializationProxy::registerProxy(typeid(RBTorsionForce), new RBTorsionForceProxy()); SerializationProxy::registerProxy(typeid(RBTorsionForce), new RBTorsionForceProxy());
SerializationProxy::registerProxy(typeid(System), new SystemProxy()); SerializationProxy::registerProxy(typeid(System), new SystemProxy());
SerializationProxy::registerProxy(typeid(State), new StateProxy());
} }
\ No newline at end of file
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2010 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "openmm/serialization/StateProxy.h"
#include <OpenMM.h>
#include <map>
using namespace std;
using namespace OpenMM;
StateProxy::StateProxy() : SerializationProxy("State") {
}
void StateProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1);
const State& s = *reinterpret_cast<const State*>(object);
node.setDoubleProperty("time", s.getTime());
Vec3 a,b,c;
s.getPeriodicBoxVectors(a,b,c);
SerializationNode& boxVectorsNode = node.createChildNode("PeriodicBoxVectors");
boxVectorsNode.createChildNode("A").setDoubleProperty("x", a[0]).setDoubleProperty("y", a[1]).setDoubleProperty("z", a[2]);
boxVectorsNode.createChildNode("B").setDoubleProperty("x", b[0]).setDoubleProperty("y", b[1]).setDoubleProperty("z", b[2]);
boxVectorsNode.createChildNode("C").setDoubleProperty("x", c[0]).setDoubleProperty("y", c[1]).setDoubleProperty("z", c[2]);
try {
s.getParameters();
SerializationNode& parametersNode = node.createChildNode("Parameters");
map<string, double> stateParams = s.getParameters();
map<string, double>::const_iterator it;
for(it = stateParams.begin(); it!=stateParams.end();it++) {
parametersNode.setDoubleProperty(it->first, it->second);
}
} catch (const OpenMMException &) {
// do nothing
}
try {
s.getPotentialEnergy();
SerializationNode& energiesNode = node.createChildNode("Energies");
energiesNode.setDoubleProperty("PotentialEnergy", s.getPotentialEnergy());
energiesNode.setDoubleProperty("KineticEnergy", s.getKineticEnergy());
} catch (const OpenMMException &) {
// do nothing
}
try {
s.getPositions();
SerializationNode& positionsNode = node.createChildNode("Positions");
vector<Vec3> statePositions = s.getPositions();
for(int i=0; i<statePositions.size();i++) {
positionsNode.createChildNode("Position").setDoubleProperty("x", statePositions[i][0]).setDoubleProperty("y", statePositions[i][1]).setDoubleProperty("z", statePositions[i][2]);
}
} catch (const OpenMMException &) {
// do nothing
}
try {
s.getVelocities();
SerializationNode& velocitiesNode = node.createChildNode("Velocities");
vector<Vec3> stateVelocities = s.getVelocities();
for(int i=0; i<stateVelocities.size();i++) {
velocitiesNode.createChildNode("Velocity").setDoubleProperty("x", stateVelocities[i][0]).setDoubleProperty("y", stateVelocities[i][1]).setDoubleProperty("z", stateVelocities[i][2]);
}
} catch (const OpenMMException &) {
// do nothing
}
try {
s.getForces();
SerializationNode& forcesNode = node.createChildNode("Forces");
vector<Vec3> stateForces = s.getForces();
for(int i=0; i<stateForces.size();i++) {
forcesNode.createChildNode("Force").setDoubleProperty("x", stateForces[i][0]).setDoubleProperty("y", stateForces[i][1]).setDoubleProperty("z", stateForces[i][2]);
}
} catch (const OpenMMException &) {
// do nothing
}
}
void* StateProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1 && node.getIntProperty("version") != 2)
throw OpenMMException("Unsupported version number");
double outTime = node.getDoubleProperty("time");
const SerializationNode& boxVectorsNode = node.getChildNode("PeriodicBoxVectors");
const SerializationNode& AVec = boxVectorsNode.getChildNode("A");
Vec3 outAVec(AVec.getDoubleProperty("x"),AVec.getDoubleProperty("y"),AVec.getDoubleProperty("z"));
const SerializationNode& BVec = boxVectorsNode.getChildNode("B");
Vec3 outBVec(BVec.getDoubleProperty("x"),BVec.getDoubleProperty("y"),BVec.getDoubleProperty("z"));
const SerializationNode& CVec = boxVectorsNode.getChildNode("C");
Vec3 outCVec(CVec.getDoubleProperty("x"),CVec.getDoubleProperty("y"),CVec.getDoubleProperty("z"));
int types = 0;
map<string, double> outStateParams;
try {
const SerializationNode& parametersNode = node.getChildNode("Parameters");
// inStateParams is really a <string,double> pair, where string is the name and double is the value
// but we want to avoid casting a string to a double and instead use the built in routines,
map<string, string> inStateParams = parametersNode.getProperties();
for(map<string, string>::const_iterator pit = inStateParams.begin(); pit != inStateParams.end(); pit++) {
outStateParams[pit->first] = parametersNode.getDoubleProperty(pit->first);
}
types = types | State::Parameters;
} catch (const OpenMMException &) {
// do nothing
}
double potentialEnergy;
double kineticEnergy;
try {
const SerializationNode& energiesNode = node.getChildNode("Energies");
potentialEnergy = energiesNode.getDoubleProperty("PotentialEnergy");
kineticEnergy = energiesNode.getDoubleProperty("KineticEnergy");
types = types | State::Energy;
} catch (const OpenMMException &) {
// do nothing
}
vector<Vec3> outPositions;
vector<Vec3> outVelocities;
vector<Vec3> outForces;
try {
const SerializationNode& positionsNode = node.getChildNode("Positions");
for(int i=0; i<(int) positionsNode.getChildren().size();i++) {
const SerializationNode& particle = positionsNode.getChildren()[i];
outPositions.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z")));
}
types = types | State::Positions;
} catch (const OpenMMException &) {
// do nothing
}
try {
const SerializationNode& velocitiesNode = node.getChildNode("Velocities");
for(int i=0; i<(int) velocitiesNode.getChildren().size();i++) {
const SerializationNode& particle = velocitiesNode.getChildren()[i];
outVelocities.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z")));
}
types = types | State::Velocities;
} catch (const OpenMMException &) {
// do nothing
}
try {
const SerializationNode& forcesNode = node.getChildNode("Forces");
for(int i=0; i<(int) forcesNode.getChildren().size();i++) {
const SerializationNode& particle = forcesNode.getChildren()[i];
outForces.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z")));
}
types = types | State::Forces;
} catch (const OpenMMException &) {
// do nothing
}
int numParticles = max(outPositions.size(), max(outForces.size(), outVelocities.size()));
vector<int> arraySizes;
State *s = new State(outTime,numParticles,types);
if(types & State::Positions) {
s->updPositions() = outPositions;
arraySizes.push_back(outPositions.size());
}
if(types & State::Velocities) {
s->updVelocities() = outVelocities;
arraySizes.push_back(outVelocities.size());
}
if(types & State::Forces) {
s->updForces() = outForces;
arraySizes.push_back(outVelocities.size());
}
if(types & State::Energy) {
s->setEnergy(kineticEnergy, potentialEnergy);
}
if(types & State::Parameters) {
s->updParameters() = outStateParams;
}
for(int i=1; i<arraySizes.size();i++) {
if(arraySizes[i] != arraySizes[i-1]) {
throw(OpenMMException("State Deserialization Particle Size Mismatch, check number of particles in Forces, Velocities, Positions!"));
}
}
s->setPeriodicBoxVectors(outAVec, outBVec, outCVec);
return s;
}
\ No newline at end of file
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2010 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/HarmonicBondForce.h"
#include "openmm/NonbondedForce.h"
#include "openmm/System.h"
#include "openmm/Context.h"
#include "openmm/LangevinIntegrator.h"
#include "openmm/AndersenThermostat.h"
#include "openmm/MonteCarloBarostat.h"
#include "openmm/serialization/XmlSerializer.h"
#include <iostream>
#include <sstream>
using namespace OpenMM;
using namespace std;
void testSerialization() {
// Create a System.
const int numParticles=50;
System system;
system.setDefaultPeriodicBoxVectors(Vec3(6.2, 0, 0), Vec3(0, 6.2, 0), Vec3(0, 0, 6.2 ));
NonbondedForce* nonbonded = new NonbondedForce();
nonbonded->setNonbondedMethod(NonbondedForce::Ewald);
nonbonded->setCutoffDistance(0.8);
nonbonded->setEwaldErrorTolerance(0.01);
for (int i = 0; i < numParticles/2; i++)
system.addParticle(22.99);
for (int i = 0; i < numParticles/2; i++)
system.addParticle(35.45);
for (int i = 0; i < numParticles/2; i++)
nonbonded->addParticle(1.0, 1.0,0.0);
for (int i = 0; i < numParticles/2; i++)
nonbonded->addParticle(-1.0, 1.0,0.0);
system.addForce(nonbonded);
system.addForce(new AndersenThermostat(393.3, 19.3));
system.addForce(new MonteCarloBarostat(25, 393.3, 25));
Integrator *intg = new LangevinIntegrator(300,79,0.002);
Context *ctxt = new Context(system, *intg);
// Set positions, velocities, forces
vector<Vec3> positions;
for(int i=0;i<numParticles;i++) {
positions.push_back(Vec3( ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2));
}
vector<Vec3> velocities;
for(int i=0;i<numParticles;i++) {
velocities.push_back(Vec3( ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2));
}
ctxt->setPositions(positions);
ctxt->setVelocities(velocities);
// Serialize and then deserialize it.
State s1 = ctxt->getState(State::Positions | State::Velocities | State::Forces | State::Energy | State::Parameters);
stringstream buffer;
XmlSerializer::serialize<State>(&s1, "State", buffer);
State* copy = XmlSerializer::deserialize<State>(buffer);
State& s2 = *copy;
// Compare the two states to see if they are identical.
vector<Vec3> pos1 = s1.getPositions();
vector<Vec3> pos2 = s2.getPositions();
ASSERT_EQUAL(pos1.size(), pos2.size());
ASSERT_EQUAL(pos1.size(), positions.size());
for(int i=0; i<pos1.size(); i++) {
ASSERT_EQUAL_VEC(pos1[i],pos2[i],0);
}
vector<Vec3> vel1 = s1.getVelocities();
vector<Vec3> vel2 = s2.getVelocities();
ASSERT_EQUAL(vel1.size(), vel2.size());
for(int i=0; i<pos1.size(); i++) {
ASSERT_EQUAL_VEC(vel1[i],vel2[i],0);
}
vector<Vec3> forces1 = s1.getForces();
vector<Vec3> forces2 = s2.getForces();
ASSERT_EQUAL(forces1.size(), forces2.size());
for(int i=0; i<pos1.size(); i++) {
ASSERT_EQUAL_VEC(forces1[i],forces2[i],0);
}
Vec3 a1,a2,a3,b1,b2,b3;
s1.getPeriodicBoxVectors(a1,a2,a3);
s2.getPeriodicBoxVectors(b1,b2,b3);
ASSERT_EQUAL_VEC(a1,b1,0);
ASSERT_EQUAL_VEC(a2,b2,0);
ASSERT_EQUAL_VEC(a3,b3,0);
ASSERT_EQUAL(s1.getPotentialEnergy(), s2.getPotentialEnergy());
ASSERT_EQUAL(s1.getKineticEnergy(), s2.getKineticEnergy());
ASSERT_EQUAL(s1.getTime(), s2.getTime());
map<string, double> p1 = s1.getParameters();
map<string, double> p2 = s2.getParameters();
ASSERT_EQUAL(p1.size(), p2.size());
map<string, double>::const_iterator it1=p1.begin();
map<string, double>::const_iterator it2=p2.begin();
//maps are ordered, so iterators should be in the same order.
for(it1=p1.begin(); it1!=p1.end(); it1++, it2++) {
assert((it1->first).compare(it2->first) == 0);
ASSERT_EQUAL(it1->second, it2->second);
}
}
int main() {
try {
testSerialization();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
return 1;
}
cout << "Done" << endl;
return 0;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment