Commit 436d542d authored by peastman's avatar peastman
Browse files

Merge pull request #410 from jchodera/debugging-group-interactions

Added missing serialization support for CustomNonbondedForce group interactions
parents 230819d1 fe5b499d
...@@ -49,6 +49,8 @@ void OPENMM_EXPORT throwException(const char* file, int line, const std::string& ...@@ -49,6 +49,8 @@ void OPENMM_EXPORT throwException(const char* file, int line, const std::string&
#define ASSERT(cond) {if (!(cond)) throwException(__FILE__, __LINE__, "");}; #define ASSERT(cond) {if (!(cond)) throwException(__FILE__, __LINE__, "");};
#define ASSERT_EQUAL(expected, found) {if (!((expected) == (found))) {std::stringstream details; details << "Expected "<<(expected)<<", found "<<(found); throwException(__FILE__, __LINE__, details.str());}}; #define ASSERT_EQUAL(expected, found) {if (!((expected) == (found))) {std::stringstream details; details << "Expected "<<(expected)<<", found "<<(found); throwException(__FILE__, __LINE__, details.str());}};
#define ASSERT_EQUAL_CONTAINERS(expected, found) {if (!((expected) == (found))) {std::stringstream details; details << "Containers not equal"; throwException(__FILE__, __LINE__, details.str());}};
#define ASSERT_EQUAL_TOL(expected, found, tol) {double _scale_ = std::abs(expected) > 1.0 ? std::abs(expected) : 1.0; if (!(std::abs((expected)-(found))/_scale_ <= (tol))) {std::stringstream details; details << "Expected "<<(expected)<<", found "<<(found); throwException(__FILE__, __LINE__, details.str());}}; #define ASSERT_EQUAL_TOL(expected, found, tol) {double _scale_ = std::abs(expected) > 1.0 ? std::abs(expected) : 1.0; if (!(std::abs((expected)-(found))/_scale_ <= (tol))) {std::stringstream details; details << "Expected "<<(expected)<<", found "<<(found); throwException(__FILE__, __LINE__, details.str());}};
......
...@@ -152,6 +152,28 @@ public: ...@@ -152,6 +152,28 @@ public:
* @param value the value to set for the property * @param value the value to set for the property
*/ */
SerializationNode& setIntProperty(const std::string& name, int value); SerializationNode& setIntProperty(const std::string& name, int value);
/**
* Get the property with a particular name, specified as an bool. If there is no property with
* the specified name, an exception is thrown.
*
* @param name the name of the property to get
*/
bool getBoolProperty(const std::string& name) const;
/**
* Get the property with a particular name, specified as a bool. If there is no property with
* the specified name, a default value is returned instead.
*
* @param name the name of the property to get
* @param defaultValue the value to return if the specified property does not exist
*/
bool getBoolProperty(const std::string& name, bool defaultValue) const;
/**
* Set the value of a property, specified as a bool.
*
* @param name the name of the property to set
* @param value the value to set for the property
*/
SerializationNode& setBoolProperty(const std::string& name, bool value);
/** /**
* Get the property with a particular name, specified as a double. If there is no property with * Get the property with a particular name, specified as a double. If there is no property with
* the specified name, an exception is thrown. * the specified name, an exception is thrown.
......
...@@ -47,6 +47,9 @@ void CustomNonbondedForceProxy::serialize(const void* object, SerializationNode& ...@@ -47,6 +47,9 @@ void CustomNonbondedForceProxy::serialize(const void* object, SerializationNode&
node.setStringProperty("energy", force.getEnergyFunction()); node.setStringProperty("energy", force.getEnergyFunction());
node.setIntProperty("method", (int) force.getNonbondedMethod()); node.setIntProperty("method", (int) force.getNonbondedMethod());
node.setDoubleProperty("cutoff", force.getCutoffDistance()); node.setDoubleProperty("cutoff", force.getCutoffDistance());
node.setBoolProperty("useSwitchingFunction", force.getUseSwitchingFunction());
node.setDoubleProperty("switchingDistance", force.getSwitchingDistance());
node.setBoolProperty("useLongRangeCorrection", force.getUseLongRangeCorrection());
SerializationNode& perParticleParams = node.createChildNode("PerParticleParameters"); SerializationNode& perParticleParams = node.createChildNode("PerParticleParameters");
for (int i = 0; i < force.getNumPerParticleParameters(); i++) { for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
perParticleParams.createChildNode("Parameter").setStringProperty("name", force.getPerParticleParameterName(i)); perParticleParams.createChildNode("Parameter").setStringProperty("name", force.getPerParticleParameterName(i));
...@@ -76,6 +79,20 @@ void CustomNonbondedForceProxy::serialize(const void* object, SerializationNode& ...@@ -76,6 +79,20 @@ void CustomNonbondedForceProxy::serialize(const void* object, SerializationNode&
SerializationNode& functions = node.createChildNode("Functions"); SerializationNode& functions = node.createChildNode("Functions");
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions.createChildNode("Function", &force.getTabulatedFunction(i)).setStringProperty("name", force.getTabulatedFunctionName(i)); functions.createChildNode("Function", &force.getTabulatedFunction(i)).setStringProperty("name", force.getTabulatedFunctionName(i));
SerializationNode& interactionGroups = node.createChildNode("InteractionGroups");
for (int i = 0; i < force.getNumInteractionGroups(); i++) {
SerializationNode& interactionGroup = interactionGroups.createChildNode("InteractionGroup");
std::set<int> set1;
std::set<int> set2;
force.getInteractionGroupParameters(i, set1, set2);
SerializationNode& set1node = interactionGroup.createChildNode("Set1");
for (std::set<int>::iterator it = set1.begin(); it != set1.end(); ++it)
set1node.createChildNode("Particle").setIntProperty("index", *it);
SerializationNode& set2node = interactionGroup.createChildNode("Set2");
for (std::set<int>::iterator it = set2.begin(); it != set2.end(); ++it)
set2node.createChildNode("Particle").setIntProperty("index", *it);
}
} }
void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) const { void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) const {
...@@ -85,7 +102,10 @@ void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) cons ...@@ -85,7 +102,10 @@ void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) cons
try { try {
CustomNonbondedForce* force = new CustomNonbondedForce(node.getStringProperty("energy")); CustomNonbondedForce* force = new CustomNonbondedForce(node.getStringProperty("energy"));
force->setNonbondedMethod((CustomNonbondedForce::NonbondedMethod) node.getIntProperty("method")); force->setNonbondedMethod((CustomNonbondedForce::NonbondedMethod) node.getIntProperty("method"));
force->setCutoffDistance(node.getDoubleProperty("cutoff")); force->setCutoffDistance(node.getDoubleProperty("cutoff", 1.0));
force->setUseSwitchingFunction(node.getBoolProperty("useSwitchingFunction", false));
force->setSwitchingDistance(node.getDoubleProperty("switchingDistance", -1.0));
force->setUseLongRangeCorrection(node.getBoolProperty("useLongRangeCorrection", false));
const SerializationNode& perParticleParams = node.getChildNode("PerParticleParameters"); const SerializationNode& perParticleParams = node.getChildNode("PerParticleParameters");
for (int i = 0; i < (int) perParticleParams.getChildren().size(); i++) { for (int i = 0; i < (int) perParticleParams.getChildren().size(); i++) {
const SerializationNode& parameter = perParticleParams.getChildren()[i]; const SerializationNode& parameter = perParticleParams.getChildren()[i];
...@@ -121,7 +141,7 @@ void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) cons ...@@ -121,7 +141,7 @@ void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) cons
} }
else { else {
// This is an old file created before TabulatedFunction existed. // This is an old file created before TabulatedFunction existed.
const SerializationNode& valuesNode = function.getChildNode("Values"); const SerializationNode& valuesNode = function.getChildNode("Values");
vector<double> values; vector<double> values;
for (int j = 0; j < (int) valuesNode.getChildren().size(); j++) for (int j = 0; j < (int) valuesNode.getChildren().size(); j++)
...@@ -129,6 +149,26 @@ void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) cons ...@@ -129,6 +149,26 @@ void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) cons
force->addTabulatedFunction(function.getStringProperty("name"), new Continuous1DFunction(values, function.getDoubleProperty("min"), function.getDoubleProperty("max"))); force->addTabulatedFunction(function.getStringProperty("name"), new Continuous1DFunction(values, function.getDoubleProperty("min"), function.getDoubleProperty("max")));
} }
} }
// Catch exceptions if InteractionGroups node is missing, in order to give backwards compatibility.
try{
const SerializationNode& interactionGroups = node.getChildNode("InteractionGroups");
for (int i = 0; i < (int) interactionGroups.getChildren().size(); i++) {
const SerializationNode& interactionGroup = interactionGroups.getChildren()[i];
// Get set 1.
const SerializationNode& set1node = interactionGroup.getChildNode("Set1");
std::set<int> set1;
for (int j = 0; j < (int) set1node.getChildren().size(); j++)
set1.insert(set1node.getChildren()[j].getIntProperty("index"));
// Get set 2.
const SerializationNode& set2node = interactionGroup.getChildNode("Set2");
std::set<int> set2;
for (int j = 0; j < (int) set2node.getChildren().size(); j++)
set2.insert(set2node.getChildren()[j].getIntProperty("index"));
force->addInteractionGroup(set1, set2);
}
} catch (...) {
// do nothing to allow backwards-compatibility
}
return force; return force;
} }
catch (...) { catch (...) {
......
...@@ -46,9 +46,19 @@ void NonbondedForceProxy::serialize(const void* object, SerializationNode& node) ...@@ -46,9 +46,19 @@ void NonbondedForceProxy::serialize(const void* object, SerializationNode& node)
const NonbondedForce& force = *reinterpret_cast<const NonbondedForce*>(object); const NonbondedForce& force = *reinterpret_cast<const NonbondedForce*>(object);
node.setIntProperty("method", (int) force.getNonbondedMethod()); node.setIntProperty("method", (int) force.getNonbondedMethod());
node.setDoubleProperty("cutoff", force.getCutoffDistance()); node.setDoubleProperty("cutoff", force.getCutoffDistance());
node.setBoolProperty("useSwitchingFunction", force.getUseSwitchingFunction());
node.setDoubleProperty("switchingDistance", force.getSwitchingDistance());
node.setDoubleProperty("ewaldTolerance", force.getEwaldErrorTolerance()); node.setDoubleProperty("ewaldTolerance", force.getEwaldErrorTolerance());
node.setDoubleProperty("rfDielectric", force.getReactionFieldDielectric()); node.setDoubleProperty("rfDielectric", force.getReactionFieldDielectric());
node.setIntProperty("dispersionCorrection", force.getUseDispersionCorrection()); node.setIntProperty("dispersionCorrection", force.getUseDispersionCorrection());
double alpha;
int nx, ny, nz;
force.getPMEParameters(alpha, nx, ny, nz);
node.setDoubleProperty("alpha", alpha);
node.setIntProperty("nx", nx);
node.setIntProperty("ny", ny);
node.setIntProperty("nz", nz);
node.setIntProperty("recipForceGroup", force.getReciprocalSpaceForceGroup());
SerializationNode& particles = node.createChildNode("Particles"); SerializationNode& particles = node.createChildNode("Particles");
for (int i = 0; i < force.getNumParticles(); i++) { for (int i = 0; i < force.getNumParticles(); i++) {
double charge, sigma, epsilon; double charge, sigma, epsilon;
...@@ -70,10 +80,20 @@ void* NonbondedForceProxy::deserialize(const SerializationNode& node) const { ...@@ -70,10 +80,20 @@ void* NonbondedForceProxy::deserialize(const SerializationNode& node) const {
NonbondedForce* force = new NonbondedForce(); NonbondedForce* force = new NonbondedForce();
try { try {
force->setNonbondedMethod((NonbondedForce::NonbondedMethod) node.getIntProperty("method")); force->setNonbondedMethod((NonbondedForce::NonbondedMethod) node.getIntProperty("method"));
force->setCutoffDistance(node.getDoubleProperty("cutoff")); force->setCutoffDistance(node.getDoubleProperty("cutoff",1.0));
force->setEwaldErrorTolerance(node.getDoubleProperty("ewaldTolerance")); force->setUseSwitchingFunction(node.getBoolProperty("useSwitchingFunction",false));
force->setReactionFieldDielectric(node.getDoubleProperty("rfDielectric")); force->setSwitchingDistance(node.getDoubleProperty("switchingDistance",-1.0));
force->setUseDispersionCorrection(node.getIntProperty("dispersionCorrection")); force->setEwaldErrorTolerance(node.getDoubleProperty("ewaldTolerance",5e-4));
force->setReactionFieldDielectric(node.getDoubleProperty("rfDielectric",78.3));
force->setUseDispersionCorrection(node.getIntProperty("dispersionCorrection",true));
double alpha;
int nx, ny, nz;
alpha = node.getDoubleProperty("alpha",0.0);
nx = node.getIntProperty("nx",0);
ny = node.getIntProperty("ny",0);
nz = node.getIntProperty("nz",0);
force->setPMEParameters(alpha, nx, ny, nz);
force->setReciprocalSpaceForceGroup(node.getIntProperty("recipForceGroup",-1));
const SerializationNode& particles = node.getChildNode("Particles"); const SerializationNode& particles = node.getChildNode("Particles");
for (int i = 0; i < (int) particles.getChildren().size(); i++) { for (int i = 0; i < (int) particles.getChildren().size(); i++) {
const SerializationNode& particle = particles.getChildren()[i]; const SerializationNode& particle = particles.getChildren()[i];
......
...@@ -121,6 +121,32 @@ SerializationNode& SerializationNode::setIntProperty(const string& name, int val ...@@ -121,6 +121,32 @@ SerializationNode& SerializationNode::setIntProperty(const string& name, int val
return *this; return *this;
} }
bool SerializationNode::getBoolProperty(const string& name) const {
map<string, string>::const_iterator iter = properties.find(name);
if (iter == properties.end())
throw OpenMMException("Unknown property '"+name+"' in node '"+getName()+"'");
bool value;
stringstream(iter->second) >> value;
return value;
}
bool SerializationNode::getBoolProperty(const string& name, bool defaultValue) const {
map<string, string>::const_iterator iter = properties.find(name);
if (iter == properties.end())
return defaultValue;
bool value;
stringstream(iter->second) >> value;
return value;
}
SerializationNode& SerializationNode::setBoolProperty(const string& name, bool value) {
stringstream s;
s << value;
properties[name] = s.str();
return *this;
}
double SerializationNode::getDoubleProperty(const string& name) const { double SerializationNode::getDoubleProperty(const string& name) const {
map<string, string>::const_iterator iter = properties.find(name); map<string, string>::const_iterator iter = properties.find(name);
if (iter == properties.end()) if (iter == properties.end())
......
...@@ -43,6 +43,9 @@ void testSerialization() { ...@@ -43,6 +43,9 @@ void testSerialization() {
CustomNonbondedForce force("5*sin(x)^2+y*z"); CustomNonbondedForce force("5*sin(x)^2+y*z");
force.setNonbondedMethod(CustomNonbondedForce::CutoffPeriodic); force.setNonbondedMethod(CustomNonbondedForce::CutoffPeriodic);
force.setUseSwitchingFunction(true);
force.setUseLongRangeCorrection(true);
force.setSwitchingDistance(2.0);
force.setCutoffDistance(2.1); force.setCutoffDistance(2.1);
force.addGlobalParameter("x", 1.3); force.addGlobalParameter("x", 1.3);
force.addGlobalParameter("y", 2.221); force.addGlobalParameter("y", 2.221);
...@@ -60,6 +63,10 @@ void testSerialization() { ...@@ -60,6 +63,10 @@ void testSerialization() {
for (int i = 0; i < 10; i++) for (int i = 0; i < 10; i++)
values[i] = sin((double) i); values[i] = sin((double) i);
force.addFunction("f", values, 0.5, 1.5); force.addFunction("f", values, 0.5, 1.5);
std::set<int> set1, set2;
set1.insert(0);
set2.insert(1);
force.addInteractionGroup(set1, set2);
// Serialize and then deserialize it. // Serialize and then deserialize it.
...@@ -73,6 +80,9 @@ void testSerialization() { ...@@ -73,6 +80,9 @@ void testSerialization() {
ASSERT_EQUAL(force.getEnergyFunction(), force2.getEnergyFunction()); ASSERT_EQUAL(force.getEnergyFunction(), force2.getEnergyFunction());
ASSERT_EQUAL(force.getNonbondedMethod(), force2.getNonbondedMethod()); ASSERT_EQUAL(force.getNonbondedMethod(), force2.getNonbondedMethod());
ASSERT_EQUAL(force.getCutoffDistance(), force2.getCutoffDistance()); ASSERT_EQUAL(force.getCutoffDistance(), force2.getCutoffDistance());
ASSERT_EQUAL(force.getSwitchingDistance(), force2.getSwitchingDistance());
ASSERT_EQUAL(force.getUseSwitchingFunction(), force2.getUseSwitchingFunction());
ASSERT_EQUAL(force.getUseLongRangeCorrection(), force2.getUseLongRangeCorrection());
ASSERT_EQUAL(force.getNumPerParticleParameters(), force2.getNumPerParticleParameters()); ASSERT_EQUAL(force.getNumPerParticleParameters(), force2.getNumPerParticleParameters());
for (int i = 0; i < force.getNumPerParticleParameters(); i++) for (int i = 0; i < force.getNumPerParticleParameters(); i++)
ASSERT_EQUAL(force.getPerParticleParameterName(i), force2.getPerParticleParameterName(i)); ASSERT_EQUAL(force.getPerParticleParameterName(i), force2.getPerParticleParameterName(i));
...@@ -111,6 +121,11 @@ void testSerialization() { ...@@ -111,6 +121,11 @@ void testSerialization() {
for (int j = 0; j < (int) val1.size(); j++) for (int j = 0; j < (int) val1.size(); j++)
ASSERT_EQUAL(val1[j], val2[j]); ASSERT_EQUAL(val1[j], val2[j]);
} }
ASSERT_EQUAL(force.getNumInteractionGroups(), force2.getNumInteractionGroups());
std::set<int> set1c, set2c;
force2.getInteractionGroupParameters(0, set1c, set2c);
ASSERT_EQUAL_CONTAINERS(set1, set1c);
ASSERT_EQUAL_CONTAINERS(set2, set2c);
} }
int main() { int main() {
......
...@@ -43,10 +43,15 @@ void testSerialization() { ...@@ -43,10 +43,15 @@ void testSerialization() {
NonbondedForce force; NonbondedForce force;
force.setNonbondedMethod(NonbondedForce::CutoffPeriodic); force.setNonbondedMethod(NonbondedForce::CutoffPeriodic);
force.setSwitchingDistance(1.5);
force.setUseSwitchingFunction(true);
force.setCutoffDistance(2.0); force.setCutoffDistance(2.0);
force.setEwaldErrorTolerance(1e-3); force.setEwaldErrorTolerance(1e-3);
force.setReactionFieldDielectric(50.0); force.setReactionFieldDielectric(50.0);
force.setUseDispersionCorrection(false); force.setUseDispersionCorrection(false);
double alpha = 0.5;
int nx = 3, ny = 5, nz = 7;
force.setPMEParameters(alpha, nx, ny, nz);
force.addParticle(1, 0.1, 0.01); force.addParticle(1, 0.1, 0.01);
force.addParticle(0.5, 0.2, 0.02); force.addParticle(0.5, 0.2, 0.02);
force.addParticle(-0.5, 0.3, 0.03); force.addParticle(-0.5, 0.3, 0.03);
...@@ -63,11 +68,20 @@ void testSerialization() { ...@@ -63,11 +68,20 @@ void testSerialization() {
NonbondedForce& force2 = *copy; NonbondedForce& force2 = *copy;
ASSERT_EQUAL(force.getNonbondedMethod(), force2.getNonbondedMethod()); ASSERT_EQUAL(force.getNonbondedMethod(), force2.getNonbondedMethod());
ASSERT_EQUAL(force.getSwitchingDistance(), force2.getSwitchingDistance());
ASSERT_EQUAL(force.getUseSwitchingFunction(), force2.getUseSwitchingFunction());
ASSERT_EQUAL(force.getCutoffDistance(), force2.getCutoffDistance()); ASSERT_EQUAL(force.getCutoffDistance(), force2.getCutoffDistance());
ASSERT_EQUAL(force.getEwaldErrorTolerance(), force2.getEwaldErrorTolerance()); ASSERT_EQUAL(force.getEwaldErrorTolerance(), force2.getEwaldErrorTolerance());
ASSERT_EQUAL(force.getReactionFieldDielectric(), force2.getReactionFieldDielectric()); ASSERT_EQUAL(force.getReactionFieldDielectric(), force2.getReactionFieldDielectric());
ASSERT_EQUAL(force.getUseDispersionCorrection(), force2.getUseDispersionCorrection()); ASSERT_EQUAL(force.getUseDispersionCorrection(), force2.getUseDispersionCorrection());
ASSERT_EQUAL(force.getNumParticles(), force2.getNumParticles()); ASSERT_EQUAL(force.getNumParticles(), force2.getNumParticles());
double alpha2;
int nx2, ny2, nz2;
force2.getPMEParameters(alpha2, nx2, ny2, nz2);
ASSERT_EQUAL(alpha, alpha2);
ASSERT_EQUAL(nx, nx2);
ASSERT_EQUAL(ny, ny2);
ASSERT_EQUAL(nz, nz2);
for (int i = 0; i < force.getNumParticles(); i++) { for (int i = 0; i < force.getNumParticles(); i++) {
double charge1, sigma1, epsilon1; double charge1, sigma1, epsilon1;
double charge2, sigma2, epsilon2; double charge2, sigma2, epsilon2;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment