Commit e670188b authored by peastman's avatar peastman
Browse files

Merge pull request #829 from peastman/state

Avoid throwing exceptions internally when serializing States
parents f11bafd9 e18eb785
...@@ -108,6 +108,10 @@ public: ...@@ -108,6 +108,10 @@ public:
* Get a map containing the values of all parameters. If this State does not contain parameters, this will throw an exception. * Get a map containing the values of all parameters. If this State does not contain parameters, this will throw an exception.
*/ */
const std::map<std::string, double>& getParameters() const; const std::map<std::string, double>& getParameters() const;
/**
* Get which data types are stored in this State. The return value is a sum of DataType flags.
*/
int getDataTypes() const;
private: private:
State(double time); State(double time);
void setPositions(const std::vector<Vec3>& pos); void setPositions(const std::vector<Vec3>& pos);
......
...@@ -76,6 +76,9 @@ const map<string, double>& State::getParameters() const { ...@@ -76,6 +76,9 @@ const map<string, double>& State::getParameters() const {
throw OpenMMException("Invoked getParameters() on a State which does not contain parameters."); throw OpenMMException("Invoked getParameters() on a State which does not contain parameters.");
return parameters; return parameters;
} }
int State::getDataTypes() const {
return types;
}
State::State(double time) : types(0), time(time), ke(0), pe(0) { State::State(double time) : types(0), time(time), ke(0), pe(0) {
} }
State::State() : types(0), time(0.0), ke(0), pe(0) { State::State() : types(0), time(0.0), ke(0), pe(0) {
......
...@@ -50,7 +50,7 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const { ...@@ -50,7 +50,7 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const {
boxVectorsNode.createChildNode("A").setDoubleProperty("x", a[0]).setDoubleProperty("y", a[1]).setDoubleProperty("z", a[2]); boxVectorsNode.createChildNode("A").setDoubleProperty("x", a[0]).setDoubleProperty("y", a[1]).setDoubleProperty("z", a[2]);
boxVectorsNode.createChildNode("B").setDoubleProperty("x", b[0]).setDoubleProperty("y", b[1]).setDoubleProperty("z", b[2]); boxVectorsNode.createChildNode("B").setDoubleProperty("x", b[0]).setDoubleProperty("y", b[1]).setDoubleProperty("z", b[2]);
boxVectorsNode.createChildNode("C").setDoubleProperty("x", c[0]).setDoubleProperty("y", c[1]).setDoubleProperty("z", c[2]); boxVectorsNode.createChildNode("C").setDoubleProperty("x", c[0]).setDoubleProperty("y", c[1]).setDoubleProperty("z", c[2]);
try { if ((s.getDataTypes()&State::Parameters) != 0) {
s.getParameters(); s.getParameters();
SerializationNode& parametersNode = node.createChildNode("Parameters"); SerializationNode& parametersNode = node.createChildNode("Parameters");
map<string, double> stateParams = s.getParameters(); map<string, double> stateParams = s.getParameters();
...@@ -58,46 +58,36 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const { ...@@ -58,46 +58,36 @@ void StateProxy::serialize(const void* object, SerializationNode& node) const {
for (it = stateParams.begin(); it!=stateParams.end();it++) { for (it = stateParams.begin(); it!=stateParams.end();it++) {
parametersNode.setDoubleProperty(it->first, it->second); parametersNode.setDoubleProperty(it->first, it->second);
} }
} catch (const OpenMMException &) {
// do nothing
} }
try { if ((s.getDataTypes()&State::Energy) != 0) {
s.getPotentialEnergy(); s.getPotentialEnergy();
SerializationNode& energiesNode = node.createChildNode("Energies"); SerializationNode& energiesNode = node.createChildNode("Energies");
energiesNode.setDoubleProperty("PotentialEnergy", s.getPotentialEnergy()); energiesNode.setDoubleProperty("PotentialEnergy", s.getPotentialEnergy());
energiesNode.setDoubleProperty("KineticEnergy", s.getKineticEnergy()); energiesNode.setDoubleProperty("KineticEnergy", s.getKineticEnergy());
} catch (const OpenMMException &) {
// do nothing
} }
try { if ((s.getDataTypes()&State::Positions) != 0) {
s.getPositions(); s.getPositions();
SerializationNode& positionsNode = node.createChildNode("Positions"); SerializationNode& positionsNode = node.createChildNode("Positions");
vector<Vec3> statePositions = s.getPositions(); vector<Vec3> statePositions = s.getPositions();
for (int i=0; i<statePositions.size();i++) { for (int i=0; i<statePositions.size();i++) {
positionsNode.createChildNode("Position").setDoubleProperty("x", statePositions[i][0]).setDoubleProperty("y", statePositions[i][1]).setDoubleProperty("z", statePositions[i][2]); positionsNode.createChildNode("Position").setDoubleProperty("x", statePositions[i][0]).setDoubleProperty("y", statePositions[i][1]).setDoubleProperty("z", statePositions[i][2]);
} }
} catch (const OpenMMException &) {
// do nothing
} }
try { if ((s.getDataTypes()&State::Velocities) != 0) {
s.getVelocities(); s.getVelocities();
SerializationNode& velocitiesNode = node.createChildNode("Velocities"); SerializationNode& velocitiesNode = node.createChildNode("Velocities");
vector<Vec3> stateVelocities = s.getVelocities(); vector<Vec3> stateVelocities = s.getVelocities();
for (int i=0; i<stateVelocities.size();i++) { for (int i=0; i<stateVelocities.size();i++) {
velocitiesNode.createChildNode("Velocity").setDoubleProperty("x", stateVelocities[i][0]).setDoubleProperty("y", stateVelocities[i][1]).setDoubleProperty("z", stateVelocities[i][2]); velocitiesNode.createChildNode("Velocity").setDoubleProperty("x", stateVelocities[i][0]).setDoubleProperty("y", stateVelocities[i][1]).setDoubleProperty("z", stateVelocities[i][2]);
} }
} catch (const OpenMMException &) {
// do nothing
} }
try { if ((s.getDataTypes()&State::Forces) != 0) {
s.getForces(); s.getForces();
SerializationNode& forcesNode = node.createChildNode("Forces"); SerializationNode& forcesNode = node.createChildNode("Forces");
vector<Vec3> stateForces = s.getForces(); vector<Vec3> stateForces = s.getForces();
for (int i=0; i<stateForces.size();i++) { for (int i=0; i<stateForces.size();i++) {
forcesNode.createChildNode("Force").setDoubleProperty("x", stateForces[i][0]).setDoubleProperty("y", stateForces[i][1]).setDoubleProperty("z", stateForces[i][2]); forcesNode.createChildNode("Force").setDoubleProperty("x", stateForces[i][0]).setDoubleProperty("y", stateForces[i][1]).setDoubleProperty("z", stateForces[i][2]);
} }
} catch (const OpenMMException &) {
// do nothing
} }
} }
...@@ -113,83 +103,54 @@ void* StateProxy::deserialize(const SerializationNode& node) const { ...@@ -113,83 +103,54 @@ void* StateProxy::deserialize(const SerializationNode& node) const {
const SerializationNode& CVec = boxVectorsNode.getChildNode("C"); const SerializationNode& CVec = boxVectorsNode.getChildNode("C");
Vec3 outCVec(CVec.getDoubleProperty("x"),CVec.getDoubleProperty("y"),CVec.getDoubleProperty("z")); Vec3 outCVec(CVec.getDoubleProperty("x"),CVec.getDoubleProperty("y"),CVec.getDoubleProperty("z"));
int types = 0; int types = 0;
map<string, double> outStateParams; vector<int> arraySizes;
try { State::StateBuilder builder(outTime);
const SerializationNode& parametersNode = node.getChildNode("Parameters"); const vector<SerializationNode>& children = node.getChildren();
// inStateParams is really a <string,double> pair, where string is the name and double is the value for (int j = 0; j < (int) children.size(); j++) {
// but we want to avoid casting a string to a double and instead use the built in routines, const SerializationNode& child = children[j];
map<string, string> inStateParams = parametersNode.getProperties(); if (child.getName() == "Parameters") {
for (map<string, string>::const_iterator pit = inStateParams.begin(); pit != inStateParams.end(); pit++) { map<string, double> outStateParams;
outStateParams[pit->first] = parametersNode.getDoubleProperty(pit->first); // inStateParams is really a <string,double> pair, where string is the name and double is the value
// but we want to avoid casting a string to a double and instead use the built in routines,
map<string, string> inStateParams = child.getProperties();
for (map<string, string>::const_iterator pit = inStateParams.begin(); pit != inStateParams.end(); pit++) {
outStateParams[pit->first] = child.getDoubleProperty(pit->first);
}
builder.setParameters(outStateParams);
} }
types = types | State::Parameters; else if (child.getName() == "Energies") {
} catch (const OpenMMException &) { double potentialEnergy = child.getDoubleProperty("PotentialEnergy");
// do nothing double kineticEnergy = child.getDoubleProperty("KineticEnergy");
} builder.setEnergy(kineticEnergy, potentialEnergy);
double potentialEnergy;
double kineticEnergy;
try {
const SerializationNode& energiesNode = node.getChildNode("Energies");
potentialEnergy = energiesNode.getDoubleProperty("PotentialEnergy");
kineticEnergy = energiesNode.getDoubleProperty("KineticEnergy");
types = types | State::Energy;
} catch (const OpenMMException &) {
// do nothing
}
vector<Vec3> outPositions;
vector<Vec3> outVelocities;
vector<Vec3> outForces;
try {
const SerializationNode& positionsNode = node.getChildNode("Positions");
for (int i = 0; i < (int) positionsNode.getChildren().size(); i++) {
const SerializationNode& particle = positionsNode.getChildren()[i];
outPositions.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z")));
} }
types = types | State::Positions; else if (child.getName() == "Positions") {
} catch (const OpenMMException &) { vector<Vec3> outPositions;
// do nothing for (int i = 0; i < (int) child.getChildren().size(); i++) {
} const SerializationNode& particle = child.getChildren()[i];
try { outPositions.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z")));
const SerializationNode& velocitiesNode = node.getChildNode("Velocities"); }
for (int i = 0; i < (int) velocitiesNode.getChildren().size(); i++) { builder.setPositions(outPositions);
const SerializationNode& particle = velocitiesNode.getChildren()[i]; arraySizes.push_back(outPositions.size());
outVelocities.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z")));
} }
types = types | State::Velocities; else if (child.getName() == "Velocities") {
} catch (const OpenMMException &) { vector<Vec3> outVelocities;
// do nothing for (int i = 0; i < (int) child.getChildren().size(); i++) {
} const SerializationNode& particle = child.getChildren()[i];
try { outVelocities.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z")));
const SerializationNode& forcesNode = node.getChildNode("Forces"); }
for (int i = 0; i < (int) forcesNode.getChildren().size(); i++) { builder.setVelocities(outVelocities);
const SerializationNode& particle = forcesNode.getChildren()[i]; arraySizes.push_back(outVelocities.size());
outForces.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z"))); }
else if (child.getName() == "Forces") {
vector<Vec3> outForces;
for (int i = 0; i < (int) child.getChildren().size(); i++) {
const SerializationNode& particle = child.getChildren()[i];
outForces.push_back(Vec3(particle.getDoubleProperty("x"),particle.getDoubleProperty("y"),particle.getDoubleProperty("z")));
}
builder.setForces(outForces);
arraySizes.push_back(outForces.size());
} }
types = types | State::Forces;
} catch (const OpenMMException &) {
// do nothing
}
vector<int> arraySizes;
State::StateBuilder builder(outTime);
if (types & State::Positions) {
builder.setPositions(outPositions);
arraySizes.push_back(outPositions.size());
}
if (types & State::Velocities) {
builder.setVelocities(outVelocities);
arraySizes.push_back(outVelocities.size());
}
if (types & State::Forces) {
builder.setForces(outForces);
arraySizes.push_back(outForces.size());
}
if (types & State::Energy) {
builder.setEnergy(kineticEnergy, potentialEnergy);
}
if (types & State::Parameters) {
builder.setParameters(outStateParams);
} }
for (int i = 1; i < arraySizes.size(); i++) { for (int i = 1; i < arraySizes.size(); i++) {
if (arraySizes[i] != arraySizes[i-1]) { if (arraySizes[i] != arraySizes[i-1]) {
throw(OpenMMException("State Deserialization Particle Size Mismatch, check number of particles in Forces, Velocities, Positions!")); throw(OpenMMException("State Deserialization Particle Size Mismatch, check number of particles in Forces, Velocities, Positions!"));
......
...@@ -66,8 +66,8 @@ void testSerialization() { ...@@ -66,8 +66,8 @@ void testSerialization() {
system.addForce(nonbonded); system.addForce(nonbonded);
system.addForce(new AndersenThermostat(393.3, 19.3)); system.addForce(new AndersenThermostat(393.3, 19.3));
system.addForce(new MonteCarloBarostat(25, 393.3, 25)); system.addForce(new MonteCarloBarostat(25, 393.3, 25));
Integrator *intg = new LangevinIntegrator(300,79,0.002); LangevinIntegrator intg(300,79,0.002);
Context *ctxt = new Context(system, *intg); Context context(system, intg);
// Set positions, velocities, forces // Set positions, velocities, forces
vector<Vec3> positions; vector<Vec3> positions;
...@@ -79,11 +79,11 @@ void testSerialization() { ...@@ -79,11 +79,11 @@ void testSerialization() {
velocities.push_back(Vec3( ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2)); velocities.push_back(Vec3( ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2, ((float) rand()/(float) RAND_MAX)*6.2));
} }
ctxt->setPositions(positions); context.setPositions(positions);
ctxt->setVelocities(velocities); context.setVelocities(velocities);
// Serialize and then deserialize it. // Serialize and then deserialize it.
State s1 = ctxt->getState(State::Positions | State::Velocities | State::Forces | State::Energy | State::Parameters); State s1 = context.getState(State::Positions | State::Velocities | State::Forces | State::Energy | State::Parameters);
stringstream buffer; stringstream buffer;
XmlSerializer::serialize<State>(&s1, "State", buffer); XmlSerializer::serialize<State>(&s1, "State", buffer);
...@@ -132,6 +132,55 @@ void testSerialization() { ...@@ -132,6 +132,55 @@ void testSerialization() {
assert((it1->first).compare(it2->first) == 0); assert((it1->first).compare(it2->first) == 0);
ASSERT_EQUAL(it1->second, it2->second); ASSERT_EQUAL(it1->second, it2->second);
} }
delete copy;
// Now create a series of States that include only one type of information. Verify
// that serialization works correctly for them.
for (int types = 1; types <= 16; types *= 2) {
State s3 = context.getState(types);
stringstream buffer2;
XmlSerializer::serialize<State>(&s3, "State", buffer2);
copy = XmlSerializer::deserialize<State>(buffer2);
int foundTypes = 0;
try {
copy->getPositions();
foundTypes += State::Positions;
}
catch (...) {
// Ignore
}
try {
copy->getVelocities();
foundTypes += State::Velocities;
}
catch (...) {
// Ignore
}
try {
copy->getForces();
foundTypes += State::Forces;
}
catch (...) {
// Ignore
}
try {
copy->getPotentialEnergy();
foundTypes += State::Energy;
}
catch (...) {
// Ignore
}
try {
copy->getParameters();
foundTypes += State::Parameters;
}
catch (...) {
// Ignore
}
delete copy;
ASSERT_EQUAL(types, foundTypes);
}
} }
int main() { int main() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment