Commit c91f4f2e authored by Peter Eastman's avatar Peter Eastman
Browse files

Implemented serialization of virtual sites

parent 56766760
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include "openmm/serialization/SerializationNode.h" #include "openmm/serialization/SerializationNode.h"
#include "openmm/Force.h" #include "openmm/Force.h"
#include "openmm/System.h" #include "openmm/System.h"
#include "openmm/VirtualSite.h"
#include <sstream> #include <sstream>
using namespace OpenMM; using namespace OpenMM;
...@@ -51,8 +52,23 @@ void SystemProxy::serialize(const void* object, SerializationNode& node) const { ...@@ -51,8 +52,23 @@ void SystemProxy::serialize(const void* object, SerializationNode& node) const {
box.createChildNode("B").setDoubleProperty("x", b[0]).setDoubleProperty("y", b[1]).setDoubleProperty("z", b[2]); box.createChildNode("B").setDoubleProperty("x", b[0]).setDoubleProperty("y", b[1]).setDoubleProperty("z", b[2]);
box.createChildNode("C").setDoubleProperty("x", c[0]).setDoubleProperty("y", c[1]).setDoubleProperty("z", c[2]); box.createChildNode("C").setDoubleProperty("x", c[0]).setDoubleProperty("y", c[1]).setDoubleProperty("z", c[2]);
SerializationNode& particles = node.createChildNode("Particles"); SerializationNode& particles = node.createChildNode("Particles");
for (int i = 0; i < system.getNumParticles(); i++) for (int i = 0; i < system.getNumParticles(); i++) {
particles.createChildNode("Particle").setDoubleProperty("mass", system.getParticleMass(i)); SerializationNode& particle = particles.createChildNode("Particle").setDoubleProperty("mass", system.getParticleMass(i));
if (system.isVirtualSite(i)) {
if (typeid(system.getVirtualSite(i)) == typeid(TwoParticleAverageSite)) {
const TwoParticleAverageSite& site = dynamic_cast<const TwoParticleAverageSite&>(system.getVirtualSite(i));
particle.createChildNode("TwoParticleAverageSite").setIntProperty("p1", site.getParticle(0)).setIntProperty("p2", site.getParticle(1)).setDoubleProperty("w1", site.getWeight(0)).setDoubleProperty("w2", site.getWeight(1));
}
else if (typeid(system.getVirtualSite(i)) == typeid(ThreeParticleAverageSite)) {
const ThreeParticleAverageSite& site = dynamic_cast<const ThreeParticleAverageSite&>(system.getVirtualSite(i));
particle.createChildNode("ThreeParticleAverageSite").setIntProperty("p1", site.getParticle(0)).setIntProperty("p2", site.getParticle(1)).setIntProperty("p3", site.getParticle(2)).setDoubleProperty("w1", site.getWeight(0)).setDoubleProperty("w2", site.getWeight(1)).setDoubleProperty("w3", site.getWeight(2));
}
else if (typeid(system.getVirtualSite(i)) == typeid(OutOfPlaneSite)) {
const OutOfPlaneSite& site = dynamic_cast<const OutOfPlaneSite&>(system.getVirtualSite(i));
particle.createChildNode("OutOfPlaneSite").setIntProperty("p1", site.getParticle(0)).setIntProperty("p2", site.getParticle(1)).setIntProperty("p3", site.getParticle(2)).setDoubleProperty("w12", site.getWeight12()).setDoubleProperty("w13", site.getWeight13()).setDoubleProperty("wc", site.getWeightCross());
}
}
}
SerializationNode& constraints = node.createChildNode("Constraints"); SerializationNode& constraints = node.createChildNode("Constraints");
for (int i = 0; i < system.getNumConstraints(); i++) { for (int i = 0; i < system.getNumConstraints(); i++) {
int particle1, particle2; int particle1, particle2;
...@@ -79,8 +95,18 @@ void* SystemProxy::deserialize(const SerializationNode& node) const { ...@@ -79,8 +95,18 @@ void* SystemProxy::deserialize(const SerializationNode& node) const {
Vec3 c(boxc.getDoubleProperty("x"), boxc.getDoubleProperty("y"), boxc.getDoubleProperty("z")); Vec3 c(boxc.getDoubleProperty("x"), boxc.getDoubleProperty("y"), boxc.getDoubleProperty("z"));
system->setDefaultPeriodicBoxVectors(a, b, c); system->setDefaultPeriodicBoxVectors(a, b, c);
const SerializationNode& particles = node.getChildNode("Particles"); const SerializationNode& particles = node.getChildNode("Particles");
for (int i = 0; i < (int) particles.getChildren().size(); i++) for (int i = 0; i < (int) particles.getChildren().size(); i++) {
system->addParticle(particles.getChildren()[i].getDoubleProperty("mass")); system->addParticle(particles.getChildren()[i].getDoubleProperty("mass"));
if (particles.getChildren()[i].getChildren().size() > 0) {
const SerializationNode& vsite = particles.getChildren()[i].getChildren()[0];
if (vsite.getName() == "TwoParticleAverageSite")
system->setVirtualSite(i, new TwoParticleAverageSite(vsite.getIntProperty("p1"), vsite.getIntProperty("p2"), vsite.getDoubleProperty("w1"), vsite.getDoubleProperty("w2")));
else if (vsite.getName() == "ThreeParticleAverageSite")
system->setVirtualSite(i, new ThreeParticleAverageSite(vsite.getIntProperty("p1"), vsite.getIntProperty("p2"), vsite.getIntProperty("p3"), vsite.getDoubleProperty("w1"), vsite.getDoubleProperty("w2"), vsite.getDoubleProperty("w3")));
else if (vsite.getName() == "OutOfPlaneSite")
system->setVirtualSite(i, new OutOfPlaneSite(vsite.getIntProperty("p1"), vsite.getIntProperty("p2"), vsite.getIntProperty("p3"), vsite.getDoubleProperty("w12"), vsite.getDoubleProperty("w13"), vsite.getDoubleProperty("wc")));
}
}
const SerializationNode& constraints = node.getChildNode("Constraints"); const SerializationNode& constraints = node.getChildNode("Constraints");
for (int i = 0; i < (int) constraints.getChildren().size(); i++) { for (int i = 0; i < (int) constraints.getChildren().size(); i++) {
const SerializationNode& constraint = constraints.getChildren()[i]; const SerializationNode& constraint = constraints.getChildren()[i];
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "openmm/internal/AssertionUtilities.h" #include "openmm/internal/AssertionUtilities.h"
#include "openmm/HarmonicBondForce.h" #include "openmm/HarmonicBondForce.h"
#include "openmm/System.h" #include "openmm/System.h"
#include "openmm/VirtualSite.h"
#include "openmm/serialization/XmlSerializer.h" #include "openmm/serialization/XmlSerializer.h"
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
...@@ -45,10 +46,15 @@ void testSerialization() { ...@@ -45,10 +46,15 @@ void testSerialization() {
System system; System system;
for (int i = 0; i < 5; i++) for (int i = 0; i < 5; i++)
system.addParticle(0.1*i+1); system.addParticle(0.1*i+1);
for (int i = 0; i < 4; i++)
system.addParticle(0.0);
system.addConstraint(0, 1, 3.0); system.addConstraint(0, 1, 3.0);
system.addConstraint(1, 2, 2.5); system.addConstraint(1, 2, 2.5);
system.addConstraint(4, 1, 1.001); system.addConstraint(4, 1, 1.001);
system.setDefaultPeriodicBoxVectors(Vec3(5, 0, 0), Vec3(0, 4, 0), Vec3(0, 0, 1.5)); system.setDefaultPeriodicBoxVectors(Vec3(5, 0, 0), Vec3(0, 4, 0), Vec3(0, 0, 1.5));
system.setVirtualSite(5, new TwoParticleAverageSite(0, 1, 0.3, 0.7));
system.setVirtualSite(6, new ThreeParticleAverageSite(2, 4, 3, 0.5, 0.2, 0.3));
system.setVirtualSite(7, new OutOfPlaneSite(0, 3, 1, 0.1, 0.2, 0.5));
system.addForce(new HarmonicBondForce()); system.addForce(new HarmonicBondForce());
// Serialize and then deserialize it. // Serialize and then deserialize it.
...@@ -80,6 +86,27 @@ void testSerialization() { ...@@ -80,6 +86,27 @@ void testSerialization() {
ASSERT_EQUAL_VEC(a, a2, 0); ASSERT_EQUAL_VEC(a, a2, 0);
ASSERT_EQUAL_VEC(b, b2, 0); ASSERT_EQUAL_VEC(b, b2, 0);
ASSERT_EQUAL_VEC(c, c2, 0); ASSERT_EQUAL_VEC(c, c2, 0);
for (int i = 0; i < system.getNumParticles(); i++)
ASSERT_EQUAL(system.isVirtualSite(i), system2.isVirtualSite(i));
const TwoParticleAverageSite& site5 = dynamic_cast<const TwoParticleAverageSite&>(system2.getVirtualSite(5));
ASSERT_EQUAL(0, site5.getParticle(0));
ASSERT_EQUAL(1, site5.getParticle(1));
ASSERT_EQUAL(0.3, site5.getWeight(0));
ASSERT_EQUAL(0.7, site5.getWeight(1));
const ThreeParticleAverageSite& site6 = dynamic_cast<const ThreeParticleAverageSite&>(system2.getVirtualSite(6));
ASSERT_EQUAL(2, site6.getParticle(0));
ASSERT_EQUAL(4, site6.getParticle(1));
ASSERT_EQUAL(3, site6.getParticle(2));
ASSERT_EQUAL(0.5, site6.getWeight(0));
ASSERT_EQUAL(0.2, site6.getWeight(1));
ASSERT_EQUAL(0.3, site6.getWeight(2));
const OutOfPlaneSite& site7 = dynamic_cast<const OutOfPlaneSite&>(system2.getVirtualSite(7));
ASSERT_EQUAL(0, site7.getParticle(0));
ASSERT_EQUAL(3, site7.getParticle(1));
ASSERT_EQUAL(1, site7.getParticle(2));
ASSERT_EQUAL(0.1, site7.getWeight12());
ASSERT_EQUAL(0.2, site7.getWeight13());
ASSERT_EQUAL(0.5, site7.getWeightCross());
ASSERT_EQUAL(system.getNumForces(), system2.getNumForces()); ASSERT_EQUAL(system.getNumForces(), system2.getNumForces());
for (int i = 0; i < system.getNumForces(); i++) for (int i = 0; i < system.getNumForces(); i++)
ASSERT(typeid(system.getForce(i)) == typeid(system2.getForce(i))) ASSERT(typeid(system.getForce(i)) == typeid(system2.getForce(i)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment