Commit a381a3ab authored by peastman's avatar peastman
Browse files

Merge branch 'master' into gayberne

parents 5ecc8e00 1f7866ad
...@@ -42,7 +42,7 @@ CustomCompoundBondForceProxy::CustomCompoundBondForceProxy() : SerializationProx ...@@ -42,7 +42,7 @@ CustomCompoundBondForceProxy::CustomCompoundBondForceProxy() : SerializationProx
} }
void CustomCompoundBondForceProxy::serialize(const void* object, SerializationNode& node) const { void CustomCompoundBondForceProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 2); node.setIntProperty("version", 3);
const CustomCompoundBondForce& force = *reinterpret_cast<const CustomCompoundBondForce*>(object); const CustomCompoundBondForce& force = *reinterpret_cast<const CustomCompoundBondForce*>(object);
node.setIntProperty("forceGroup", force.getForceGroup()); node.setIntProperty("forceGroup", force.getForceGroup());
node.setBoolProperty("usesPeriodic", force.usesPeriodicBoundaryConditions()); node.setBoolProperty("usesPeriodic", force.usesPeriodicBoundaryConditions());
...@@ -56,6 +56,10 @@ void CustomCompoundBondForceProxy::serialize(const void* object, SerializationNo ...@@ -56,6 +56,10 @@ void CustomCompoundBondForceProxy::serialize(const void* object, SerializationNo
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParams.createChildNode("Parameter").setStringProperty("name", force.getGlobalParameterName(i)).setDoubleProperty("default", force.getGlobalParameterDefaultValue(i)); globalParams.createChildNode("Parameter").setStringProperty("name", force.getGlobalParameterName(i)).setDoubleProperty("default", force.getGlobalParameterDefaultValue(i));
} }
SerializationNode& energyDerivs = node.createChildNode("EnergyParameterDerivatives");
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
energyDerivs.createChildNode("Parameter").setStringProperty("name", force.getEnergyParameterDerivativeName(i));
}
SerializationNode& bonds = node.createChildNode("Bonds"); SerializationNode& bonds = node.createChildNode("Bonds");
for (int i = 0; i < force.getNumBonds(); i++) { for (int i = 0; i < force.getNumBonds(); i++) {
vector<int> particles; vector<int> particles;
...@@ -82,7 +86,7 @@ void CustomCompoundBondForceProxy::serialize(const void* object, SerializationNo ...@@ -82,7 +86,7 @@ void CustomCompoundBondForceProxy::serialize(const void* object, SerializationNo
void* CustomCompoundBondForceProxy::deserialize(const SerializationNode& node) const { void* CustomCompoundBondForceProxy::deserialize(const SerializationNode& node) const {
int version = node.getIntProperty("version"); int version = node.getIntProperty("version");
if (version < 1 || version > 2) if (version < 1 || version > 3)
throw OpenMMException("Unsupported version number"); throw OpenMMException("Unsupported version number");
CustomCompoundBondForce* force = NULL; CustomCompoundBondForce* force = NULL;
try { try {
...@@ -100,6 +104,13 @@ void* CustomCompoundBondForceProxy::deserialize(const SerializationNode& node) c ...@@ -100,6 +104,13 @@ void* CustomCompoundBondForceProxy::deserialize(const SerializationNode& node) c
const SerializationNode& parameter = globalParams.getChildren()[i]; const SerializationNode& parameter = globalParams.getChildren()[i];
force->addGlobalParameter(parameter.getStringProperty("name"), parameter.getDoubleProperty("default")); force->addGlobalParameter(parameter.getStringProperty("name"), parameter.getDoubleProperty("default"));
} }
if (version > 2) {
const SerializationNode& energyDerivs = node.getChildNode("EnergyParameterDerivatives");
for (int i = 0; i < (int) energyDerivs.getChildren().size(); i++) {
const SerializationNode& parameter = energyDerivs.getChildren()[i];
force->addEnergyParameterDerivative(parameter.getStringProperty("name"));
}
}
const SerializationNode& bonds = node.getChildNode("Bonds"); const SerializationNode& bonds = node.getChildNode("Bonds");
vector<int> particles(force->getNumParticlesPerBond()); vector<int> particles(force->getNumParticlesPerBond());
vector<double> params(force->getNumPerBondParameters()); vector<double> params(force->getNumPerBondParameters());
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2010-2014 Stanford University and the Authors. * * Portions copyright (c) 2010-2016 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -42,7 +42,7 @@ CustomGBForceProxy::CustomGBForceProxy() : SerializationProxy("CustomGBForce") { ...@@ -42,7 +42,7 @@ CustomGBForceProxy::CustomGBForceProxy() : SerializationProxy("CustomGBForce") {
} }
void CustomGBForceProxy::serialize(const void* object, SerializationNode& node) const { void CustomGBForceProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1); node.setIntProperty("version", 2);
const CustomGBForce& force = *reinterpret_cast<const CustomGBForce*>(object); const CustomGBForce& force = *reinterpret_cast<const CustomGBForce*>(object);
node.setIntProperty("forceGroup", force.getForceGroup()); node.setIntProperty("forceGroup", force.getForceGroup());
node.setIntProperty("method", (int) force.getNonbondedMethod()); node.setIntProperty("method", (int) force.getNonbondedMethod());
...@@ -55,6 +55,10 @@ void CustomGBForceProxy::serialize(const void* object, SerializationNode& node) ...@@ -55,6 +55,10 @@ void CustomGBForceProxy::serialize(const void* object, SerializationNode& node)
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParams.createChildNode("Parameter").setStringProperty("name", force.getGlobalParameterName(i)).setDoubleProperty("default", force.getGlobalParameterDefaultValue(i)); globalParams.createChildNode("Parameter").setStringProperty("name", force.getGlobalParameterName(i)).setDoubleProperty("default", force.getGlobalParameterDefaultValue(i));
} }
SerializationNode& energyDerivs = node.createChildNode("EnergyParameterDerivatives");
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
energyDerivs.createChildNode("Parameter").setStringProperty("name", force.getEnergyParameterDerivativeName(i));
}
SerializationNode& computedValues = node.createChildNode("ComputedValues"); SerializationNode& computedValues = node.createChildNode("ComputedValues");
for (int i = 0; i < force.getNumComputedValues(); i++) { for (int i = 0; i < force.getNumComputedValues(); i++) {
string name, expression; string name, expression;
...@@ -93,7 +97,8 @@ void CustomGBForceProxy::serialize(const void* object, SerializationNode& node) ...@@ -93,7 +97,8 @@ void CustomGBForceProxy::serialize(const void* object, SerializationNode& node)
} }
void* CustomGBForceProxy::deserialize(const SerializationNode& node) const { void* CustomGBForceProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1) int version = node.getIntProperty("version");
if (version < 1 || version > 2)
throw OpenMMException("Unsupported version number"); throw OpenMMException("Unsupported version number");
CustomGBForce* force = NULL; CustomGBForce* force = NULL;
try { try {
...@@ -111,6 +116,13 @@ void* CustomGBForceProxy::deserialize(const SerializationNode& node) const { ...@@ -111,6 +116,13 @@ void* CustomGBForceProxy::deserialize(const SerializationNode& node) const {
const SerializationNode& parameter = globalParams.getChildren()[i]; const SerializationNode& parameter = globalParams.getChildren()[i];
force->addGlobalParameter(parameter.getStringProperty("name"), parameter.getDoubleProperty("default")); force->addGlobalParameter(parameter.getStringProperty("name"), parameter.getDoubleProperty("default"));
} }
if (version > 1) {
const SerializationNode& energyDerivs = node.getChildNode("EnergyParameterDerivatives");
for (int i = 0; i < (int) energyDerivs.getChildren().size(); i++) {
const SerializationNode& parameter = energyDerivs.getChildren()[i];
force->addEnergyParameterDerivative(parameter.getStringProperty("name"));
}
}
const SerializationNode& computedValues = node.getChildNode("ComputedValues"); const SerializationNode& computedValues = node.getChildNode("ComputedValues");
for (int i = 0; i < (int) computedValues.getChildren().size(); i++) { for (int i = 0; i < (int) computedValues.getChildren().size(); i++) {
const SerializationNode& value = computedValues.getChildren()[i]; const SerializationNode& value = computedValues.getChildren()[i];
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2010-2014 Stanford University and the Authors. * * Portions copyright (c) 2010-2016 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -42,7 +42,7 @@ CustomNonbondedForceProxy::CustomNonbondedForceProxy() : SerializationProxy("Cus ...@@ -42,7 +42,7 @@ CustomNonbondedForceProxy::CustomNonbondedForceProxy() : SerializationProxy("Cus
} }
void CustomNonbondedForceProxy::serialize(const void* object, SerializationNode& node) const { void CustomNonbondedForceProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1); node.setIntProperty("version", 2);
const CustomNonbondedForce& force = *reinterpret_cast<const CustomNonbondedForce*>(object); const CustomNonbondedForce& force = *reinterpret_cast<const CustomNonbondedForce*>(object);
node.setIntProperty("forceGroup", force.getForceGroup()); node.setIntProperty("forceGroup", force.getForceGroup());
node.setStringProperty("energy", force.getEnergyFunction()); node.setStringProperty("energy", force.getEnergyFunction());
...@@ -59,6 +59,10 @@ void CustomNonbondedForceProxy::serialize(const void* object, SerializationNode& ...@@ -59,6 +59,10 @@ void CustomNonbondedForceProxy::serialize(const void* object, SerializationNode&
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParams.createChildNode("Parameter").setStringProperty("name", force.getGlobalParameterName(i)).setDoubleProperty("default", force.getGlobalParameterDefaultValue(i)); globalParams.createChildNode("Parameter").setStringProperty("name", force.getGlobalParameterName(i)).setDoubleProperty("default", force.getGlobalParameterDefaultValue(i));
} }
SerializationNode& energyDerivs = node.createChildNode("EnergyParameterDerivatives");
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
energyDerivs.createChildNode("Parameter").setStringProperty("name", force.getEnergyParameterDerivativeName(i));
}
SerializationNode& particles = node.createChildNode("Particles"); SerializationNode& particles = node.createChildNode("Particles");
for (int i = 0; i < force.getNumParticles(); i++) { for (int i = 0; i < force.getNumParticles(); i++) {
vector<double> params; vector<double> params;
...@@ -97,7 +101,8 @@ void CustomNonbondedForceProxy::serialize(const void* object, SerializationNode& ...@@ -97,7 +101,8 @@ void CustomNonbondedForceProxy::serialize(const void* object, SerializationNode&
} }
void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) const { void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1) int version = node.getIntProperty("version");
if (version < 1 || version > 2)
throw OpenMMException("Unsupported version number"); throw OpenMMException("Unsupported version number");
CustomNonbondedForce* force = NULL; CustomNonbondedForce* force = NULL;
try { try {
...@@ -118,6 +123,13 @@ void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) cons ...@@ -118,6 +123,13 @@ void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) cons
const SerializationNode& parameter = globalParams.getChildren()[i]; const SerializationNode& parameter = globalParams.getChildren()[i];
force->addGlobalParameter(parameter.getStringProperty("name"), parameter.getDoubleProperty("default")); force->addGlobalParameter(parameter.getStringProperty("name"), parameter.getDoubleProperty("default"));
} }
if (version > 1) {
const SerializationNode& energyDerivs = node.getChildNode("EnergyParameterDerivatives");
for (int i = 0; i < (int) energyDerivs.getChildren().size(); i++) {
const SerializationNode& parameter = energyDerivs.getChildren()[i];
force->addEnergyParameterDerivative(parameter.getStringProperty("name"));
}
}
const SerializationNode& particles = node.getChildNode("Particles"); const SerializationNode& particles = node.getChildNode("Particles");
vector<double> params(force->getNumPerParticleParameters()); vector<double> params(force->getNumPerParticleParameters());
for (int i = 0; i < (int) particles.getChildren().size(); i++) { for (int i = 0; i < (int) particles.getChildren().size(); i++) {
......
...@@ -42,7 +42,7 @@ CustomTorsionForceProxy::CustomTorsionForceProxy() : SerializationProxy("CustomT ...@@ -42,7 +42,7 @@ CustomTorsionForceProxy::CustomTorsionForceProxy() : SerializationProxy("CustomT
} }
void CustomTorsionForceProxy::serialize(const void* object, SerializationNode& node) const { void CustomTorsionForceProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 2); node.setIntProperty("version", 3);
const CustomTorsionForce& force = *reinterpret_cast<const CustomTorsionForce*>(object); const CustomTorsionForce& force = *reinterpret_cast<const CustomTorsionForce*>(object);
node.setIntProperty("forceGroup", force.getForceGroup()); node.setIntProperty("forceGroup", force.getForceGroup());
node.setBoolProperty("usesPeriodic", force.usesPeriodicBoundaryConditions()); node.setBoolProperty("usesPeriodic", force.usesPeriodicBoundaryConditions());
...@@ -55,6 +55,10 @@ void CustomTorsionForceProxy::serialize(const void* object, SerializationNode& n ...@@ -55,6 +55,10 @@ void CustomTorsionForceProxy::serialize(const void* object, SerializationNode& n
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParams.createChildNode("Parameter").setStringProperty("name", force.getGlobalParameterName(i)).setDoubleProperty("default", force.getGlobalParameterDefaultValue(i)); globalParams.createChildNode("Parameter").setStringProperty("name", force.getGlobalParameterName(i)).setDoubleProperty("default", force.getGlobalParameterDefaultValue(i));
} }
SerializationNode& energyDerivs = node.createChildNode("EnergyParameterDerivatives");
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
energyDerivs.createChildNode("Parameter").setStringProperty("name", force.getEnergyParameterDerivativeName(i));
}
SerializationNode& torsions = node.createChildNode("Torsions"); SerializationNode& torsions = node.createChildNode("Torsions");
for (int i = 0; i < force.getNumTorsions(); i++) { for (int i = 0; i < force.getNumTorsions(); i++) {
int p1, p2, p3, p4; int p1, p2, p3, p4;
...@@ -72,7 +76,7 @@ void CustomTorsionForceProxy::serialize(const void* object, SerializationNode& n ...@@ -72,7 +76,7 @@ void CustomTorsionForceProxy::serialize(const void* object, SerializationNode& n
void* CustomTorsionForceProxy::deserialize(const SerializationNode& node) const { void* CustomTorsionForceProxy::deserialize(const SerializationNode& node) const {
int version = node.getIntProperty("version"); int version = node.getIntProperty("version");
if (version < 1 || version > 2) if (version < 1 || version > 3)
throw OpenMMException("Unsupported version number"); throw OpenMMException("Unsupported version number");
CustomTorsionForce* force = NULL; CustomTorsionForce* force = NULL;
try { try {
...@@ -90,6 +94,13 @@ void* CustomTorsionForceProxy::deserialize(const SerializationNode& node) const ...@@ -90,6 +94,13 @@ void* CustomTorsionForceProxy::deserialize(const SerializationNode& node) const
const SerializationNode& parameter = globalParams.getChildren()[i]; const SerializationNode& parameter = globalParams.getChildren()[i];
force->addGlobalParameter(parameter.getStringProperty("name"), parameter.getDoubleProperty("default")); force->addGlobalParameter(parameter.getStringProperty("name"), parameter.getDoubleProperty("default"));
} }
if (version > 2) {
const SerializationNode& energyDerivs = node.getChildNode("EnergyParameterDerivatives");
for (int i = 0; i < (int) energyDerivs.getChildren().size(); i++) {
const SerializationNode& parameter = energyDerivs.getChildren()[i];
force->addEnergyParameterDerivative(parameter.getStringProperty("name"));
}
}
const SerializationNode& torsions = node.getChildNode("Torsions"); const SerializationNode& torsions = node.getChildNode("Torsions");
vector<double> params(force->getNumPerTorsionParameters()); vector<double> params(force->getNumPerTorsionParameters());
for (int i = 0; i < (int) torsions.getChildren().size(); i++) { for (int i = 0; i < (int) torsions.getChildren().size(); i++) {
......
...@@ -46,6 +46,7 @@ void testSerialization() { ...@@ -46,6 +46,7 @@ void testSerialization() {
force.addGlobalParameter("x", 1.3); force.addGlobalParameter("x", 1.3);
force.addGlobalParameter("y", 2.221); force.addGlobalParameter("y", 2.221);
force.addPerAngleParameter("z"); force.addPerAngleParameter("z");
force.addEnergyParameterDerivative("y");
vector<double> params(1); vector<double> params(1);
params[0] = 1.0; params[0] = 1.0;
force.addAngle(1, 2, 3, params); force.addAngle(1, 2, 3, params);
...@@ -74,6 +75,9 @@ void testSerialization() { ...@@ -74,6 +75,9 @@ void testSerialization() {
ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i)); ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i));
ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i)); ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i));
} }
ASSERT_EQUAL(force.getNumEnergyParameterDerivatives(), force2.getNumEnergyParameterDerivatives());
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
ASSERT_EQUAL(force.getEnergyParameterDerivativeName(i), force2.getEnergyParameterDerivativeName(i));
ASSERT_EQUAL(force.usesPeriodicBoundaryConditions(), force2.usesPeriodicBoundaryConditions()); ASSERT_EQUAL(force.usesPeriodicBoundaryConditions(), force2.usesPeriodicBoundaryConditions());
ASSERT_EQUAL(force.getNumAngles(), force2.getNumAngles()); ASSERT_EQUAL(force.getNumAngles(), force2.getNumAngles());
for (int i = 0; i < force.getNumAngles(); i++) { for (int i = 0; i < force.getNumAngles(); i++) {
......
...@@ -46,6 +46,7 @@ void testSerialization() { ...@@ -46,6 +46,7 @@ void testSerialization() {
force.addGlobalParameter("x", 1.3); force.addGlobalParameter("x", 1.3);
force.addGlobalParameter("y", 2.221); force.addGlobalParameter("y", 2.221);
force.addPerBondParameter("z"); force.addPerBondParameter("z");
force.addEnergyParameterDerivative("y");
vector<double> params(1); vector<double> params(1);
params[0] = 1.0; params[0] = 1.0;
force.addBond(1, 2, params); force.addBond(1, 2, params);
...@@ -74,6 +75,9 @@ void testSerialization() { ...@@ -74,6 +75,9 @@ void testSerialization() {
ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i)); ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i));
ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i)); ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i));
} }
ASSERT_EQUAL(force.getNumEnergyParameterDerivatives(), force2.getNumEnergyParameterDerivatives());
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
ASSERT_EQUAL(force.getEnergyParameterDerivativeName(i), force2.getEnergyParameterDerivativeName(i));
ASSERT_EQUAL(force.usesPeriodicBoundaryConditions(), force2.usesPeriodicBoundaryConditions()); ASSERT_EQUAL(force.usesPeriodicBoundaryConditions(), force2.usesPeriodicBoundaryConditions());
ASSERT_EQUAL(force.getNumBonds(), force2.getNumBonds()); ASSERT_EQUAL(force.getNumBonds(), force2.getNumBonds());
for (int i = 0; i < force.getNumBonds(); i++) { for (int i = 0; i < force.getNumBonds(); i++) {
......
...@@ -46,6 +46,7 @@ void testSerialization() { ...@@ -46,6 +46,7 @@ void testSerialization() {
force.addGlobalParameter("x", 1.3); force.addGlobalParameter("x", 1.3);
force.addGlobalParameter("y", 2.221); force.addGlobalParameter("y", 2.221);
force.addPerBondParameter("z"); force.addPerBondParameter("z");
force.addEnergyParameterDerivative("y");
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
vector<int> particles; vector<int> particles;
vector<double> weights; vector<double> weights;
...@@ -99,6 +100,9 @@ void testSerialization() { ...@@ -99,6 +100,9 @@ void testSerialization() {
ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i)); ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i));
ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i)); ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i));
} }
ASSERT_EQUAL(force.getNumEnergyParameterDerivatives(), force2.getNumEnergyParameterDerivatives());
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
ASSERT_EQUAL(force.getEnergyParameterDerivativeName(i), force2.getEnergyParameterDerivativeName(i));
ASSERT_EQUAL(force.usesPeriodicBoundaryConditions(), force2.usesPeriodicBoundaryConditions()); ASSERT_EQUAL(force.usesPeriodicBoundaryConditions(), force2.usesPeriodicBoundaryConditions());
ASSERT_EQUAL(force.getNumGroups(), force2.getNumGroups()); ASSERT_EQUAL(force.getNumGroups(), force2.getNumGroups());
for (int i = 0; i < force.getNumGroups(); i++) { for (int i = 0; i < force.getNumGroups(); i++) {
......
...@@ -46,6 +46,7 @@ void testSerialization() { ...@@ -46,6 +46,7 @@ void testSerialization() {
force.addGlobalParameter("x", 1.3); force.addGlobalParameter("x", 1.3);
force.addGlobalParameter("y", 2.221); force.addGlobalParameter("y", 2.221);
force.addPerBondParameter("z"); force.addPerBondParameter("z");
force.addEnergyParameterDerivative("y");
vector<int> particles(3); vector<int> particles(3);
vector<double> params(1); vector<double> params(1);
particles[0] = 0; particles[0] = 0;
...@@ -89,6 +90,9 @@ void testSerialization() { ...@@ -89,6 +90,9 @@ void testSerialization() {
ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i)); ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i));
ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i)); ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i));
} }
ASSERT_EQUAL(force.getNumEnergyParameterDerivatives(), force2.getNumEnergyParameterDerivatives());
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
ASSERT_EQUAL(force.getEnergyParameterDerivativeName(i), force2.getEnergyParameterDerivativeName(i));
ASSERT_EQUAL(force.usesPeriodicBoundaryConditions(), force2.usesPeriodicBoundaryConditions()); ASSERT_EQUAL(force.usesPeriodicBoundaryConditions(), force2.usesPeriodicBoundaryConditions());
ASSERT_EQUAL(force.getNumBonds(), force2.getNumBonds()); ASSERT_EQUAL(force.getNumBonds(), force2.getNumBonds());
for (int i = 0; i < force.getNumBonds(); i++) { for (int i = 0; i < force.getNumBonds(); i++) {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2010-2014 Stanford University and the Authors. * * Portions copyright (c) 2010-2016 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -48,6 +48,7 @@ void testSerialization() { ...@@ -48,6 +48,7 @@ void testSerialization() {
force.addGlobalParameter("x", 1.3); force.addGlobalParameter("x", 1.3);
force.addGlobalParameter("y", 2.221); force.addGlobalParameter("y", 2.221);
force.addPerParticleParameter("z"); force.addPerParticleParameter("z");
force.addEnergyParameterDerivative("y");
force.addComputedValue("a", "x+1", CustomGBForce::ParticlePairNoExclusions); force.addComputedValue("a", "x+1", CustomGBForce::ParticlePairNoExclusions);
force.addComputedValue("b", "y-1", CustomGBForce::SingleParticle); force.addComputedValue("b", "y-1", CustomGBForce::SingleParticle);
force.addEnergyTerm("a*b", CustomGBForce::SingleParticle); force.addEnergyTerm("a*b", CustomGBForce::SingleParticle);
...@@ -86,6 +87,9 @@ void testSerialization() { ...@@ -86,6 +87,9 @@ void testSerialization() {
ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i)); ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i));
ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i)); ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i));
} }
ASSERT_EQUAL(force.getNumEnergyParameterDerivatives(), force2.getNumEnergyParameterDerivatives());
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
ASSERT_EQUAL(force.getEnergyParameterDerivativeName(i), force2.getEnergyParameterDerivativeName(i));
ASSERT_EQUAL(force.getNumComputedValues(), force2.getNumComputedValues()); ASSERT_EQUAL(force.getNumComputedValues(), force2.getNumComputedValues());
for (int i = 0; i < force.getNumComputedValues(); i++) { for (int i = 0; i < force.getNumComputedValues(); i++) {
string name1, name2, expression1, expression2; string name1, name2, expression1, expression2;
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2010-2014 Stanford University and the Authors. * * Portions copyright (c) 2010-2016 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -51,6 +51,7 @@ void testSerialization() { ...@@ -51,6 +51,7 @@ void testSerialization() {
force.addGlobalParameter("x", 1.3); force.addGlobalParameter("x", 1.3);
force.addGlobalParameter("y", 2.221); force.addGlobalParameter("y", 2.221);
force.addPerParticleParameter("z"); force.addPerParticleParameter("z");
force.addEnergyParameterDerivative("y");
vector<double> params(1); vector<double> params(1);
params[0] = 1.0; params[0] = 1.0;
force.addParticle(params); force.addParticle(params);
...@@ -94,6 +95,9 @@ void testSerialization() { ...@@ -94,6 +95,9 @@ void testSerialization() {
ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i)); ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i));
ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i)); ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i));
} }
ASSERT_EQUAL(force.getNumEnergyParameterDerivatives(), force2.getNumEnergyParameterDerivatives());
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
ASSERT_EQUAL(force.getEnergyParameterDerivativeName(i), force2.getEnergyParameterDerivativeName(i));
ASSERT_EQUAL(force.getNumParticles(), force2.getNumParticles()); ASSERT_EQUAL(force.getNumParticles(), force2.getNumParticles());
for (int i = 0; i < force.getNumParticles(); i++) { for (int i = 0; i < force.getNumParticles(); i++) {
vector<double> params1, params2; vector<double> params1, params2;
......
...@@ -46,6 +46,7 @@ void testSerialization() { ...@@ -46,6 +46,7 @@ void testSerialization() {
force.addGlobalParameter("x", 1.3); force.addGlobalParameter("x", 1.3);
force.addGlobalParameter("y", 2.221); force.addGlobalParameter("y", 2.221);
force.addPerTorsionParameter("z"); force.addPerTorsionParameter("z");
force.addEnergyParameterDerivative("y");
vector<double> params(1); vector<double> params(1);
params[0] = 1.0; params[0] = 1.0;
force.addTorsion(1, 2, 3, 4, params); force.addTorsion(1, 2, 3, 4, params);
...@@ -74,6 +75,9 @@ void testSerialization() { ...@@ -74,6 +75,9 @@ void testSerialization() {
ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i)); ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i));
ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i)); ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i));
} }
ASSERT_EQUAL(force.getNumEnergyParameterDerivatives(), force2.getNumEnergyParameterDerivatives());
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
ASSERT_EQUAL(force.getEnergyParameterDerivativeName(i), force2.getEnergyParameterDerivativeName(i));
ASSERT_EQUAL(force.usesPeriodicBoundaryConditions(), force2.usesPeriodicBoundaryConditions()); ASSERT_EQUAL(force.usesPeriodicBoundaryConditions(), force2.usesPeriodicBoundaryConditions());
ASSERT_EQUAL(force.getNumTorsions(), force2.getNumTorsions()); ASSERT_EQUAL(force.getNumTorsions(), force2.getNumTorsions());
for (int i = 0; i < force.getNumTorsions(); i++) { for (int i = 0; i < force.getNumTorsions(); i++) {
......
...@@ -181,6 +181,41 @@ void testPeriodic() { ...@@ -181,6 +181,41 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(0.5*1.1*(M_PI/6)*(M_PI/6), state.getPotentialEnergy(), TOL); ASSERT_EQUAL_TOL(0.5*1.1*(M_PI/6)*(M_PI/6), state.getPotentialEnergy(), TOL);
} }
void testEnergyParameterDerivatives() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomAngleForce* angles = new CustomAngleForce("k*(theta-theta0)^2");
angles->addGlobalParameter("theta0", 0.0);
angles->addGlobalParameter("k", 0.0);
angles->addEnergyParameterDerivative("theta0");
angles->addEnergyParameterDerivative("k");
vector<double> parameters;
angles->addAngle(0, 1, 2, parameters);
system.addForce(angles);
Context context(system, integrator, platform);
vector<Vec3> positions(3);
positions[0] = Vec3(0, 2, 0);
positions[1] = Vec3(0, 0, 0);
positions[2] = Vec3(1, 1, 0);
context.setPositions(positions);
double theta = M_PI/4;
for (int i = 0; i < 10; i++) {
double theta0 = 0.1*i;
double k = 10-i;
context.setParameter("theta0", theta0);
context.setParameter("k", k);
State state = context.getState(State::ParameterDerivatives);
map<string, double> derivs = state.getEnergyParameterDerivatives();
double dEdtheta0 = -2*k*(theta-theta0);
double dEdk = (theta-theta0)*(theta-theta0);
ASSERT_EQUAL_TOL(dEdtheta0, derivs["theta0"], 1e-5);
ASSERT_EQUAL_TOL(dEdk, derivs["k"], 1e-5);
}
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -189,6 +224,7 @@ int main(int argc, char* argv[]) { ...@@ -189,6 +224,7 @@ int main(int argc, char* argv[]) {
testAngles(); testAngles();
testIllegalVariable(); testIllegalVariable();
testPeriodic(); testPeriodic();
testEnergyParameterDerivatives();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2015 Stanford University and the Authors. * * Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -178,6 +178,41 @@ void testPeriodic() { ...@@ -178,6 +178,41 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(0.5*0.8*0.9*0.9, state.getPotentialEnergy(), TOL); ASSERT_EQUAL_TOL(0.5*0.8*0.9*0.9, state.getPotentialEnergy(), TOL);
} }
void testEnergyParameterDerivatives() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomBondForce* bonds = new CustomBondForce("k*(r-r0)^2");
bonds->addGlobalParameter("r0", 0.0);
bonds->addGlobalParameter("k", 0.0);
bonds->addEnergyParameterDerivative("k");
bonds->addEnergyParameterDerivative("r0");
vector<double> parameters;
bonds->addBond(0, 1, parameters);
bonds->addBond(1, 2, parameters);
system.addForce(bonds);
Context context(system, integrator, platform);
vector<Vec3> positions(3);
positions[0] = Vec3(0, 2, 0);
positions[1] = Vec3(0, 0, 0);
positions[2] = Vec3(1, 0, 0);
context.setPositions(positions);
for (int i = 0; i < 10; i++) {
double r0 = 0.1*i;
double k = 10-i;
context.setParameter("r0", r0);
context.setParameter("k", k);
State state = context.getState(State::ParameterDerivatives);
map<string, double> derivs = state.getEnergyParameterDerivatives();
double dEdr0 = -2*k*((2-r0)+(1-r0));
double dEdk = (2-r0)*(2-r0) + (1-r0)*(1-r0);
ASSERT_EQUAL_TOL(dEdr0, derivs["r0"], 1e-5);
ASSERT_EQUAL_TOL(dEdk, derivs["k"], 1e-5);
}
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -187,6 +222,7 @@ int main(int argc, char* argv[]) { ...@@ -187,6 +222,7 @@ int main(int argc, char* argv[]) {
testManyParameters(); testManyParameters();
testIllegalVariable(); testIllegalVariable();
testPeriodic(); testPeriodic();
testEnergyParameterDerivatives();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -333,6 +333,65 @@ void testPeriodic() { ...@@ -333,6 +333,65 @@ void testPeriodic() {
ASSERT_EQUAL_VEC(Vec3(2*0.5*(5.0/12.0), 0, 0), state.getForces()[4], TOL); ASSERT_EQUAL_VEC(Vec3(2*0.5*(5.0/12.0), 0, 0), state.getForces()[4], TOL);
} }
void testEnergyParameterDerivatives() {
System system;
system.addParticle(1.0);
system.addParticle(2.0);
system.addParticle(3.0);
system.addParticle(4.0);
system.addParticle(5.0);
CustomCentroidBondForce* force = new CustomCentroidBondForce(2, "k*(distance(g1,g2)-r0)^2");
force->addGlobalParameter("r0", 0.0);
force->addGlobalParameter("k", 0.0);
force->addEnergyParameterDerivative("r0");
force->addEnergyParameterDerivative("k");
vector<int> particles1;
particles1.push_back(0);
particles1.push_back(1);
vector<int> particles2;
particles2.push_back(2);
particles2.push_back(3);
particles2.push_back(4);
force->addGroup(particles1);
force->addGroup(particles2);
vector<int> groups;
groups.push_back(0);
groups.push_back(1);
vector<double> parameters;
force->addBond(groups, parameters);
system.addForce(force);
// The center of mass of group 0 is (1.5, 0, 0).
vector<Vec3> positions(5);
positions[0] = Vec3(2.5, 0, 0);
positions[1] = Vec3(1, 0, 0);
// The center of mass of group 1 is (-1, 0, 0).
positions[2] = Vec3(-6, 0, 0);
positions[3] = Vec3(-1, 0, 0);
positions[4] = Vec3(2, 0, 0);
// Check the derivatives.
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
context.setPositions(positions);
for (int i = 0; i < 10; i++) {
double r0 = 0.1*i;
double k = 10-i;
context.setParameter("r0", r0);
context.setParameter("k", k);
State state = context.getState(State::ParameterDerivatives);
map<string, double> derivs = state.getEnergyParameterDerivatives();
double dEdr0 = -2*k*(2.5-r0);
double dEdk = (2.5-r0)*(2.5-r0);
ASSERT_EQUAL_TOL(dEdr0, derivs["r0"], 1e-5);
ASSERT_EQUAL_TOL(dEdk, derivs["k"], 1e-5);
}
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -343,6 +402,7 @@ int main(int argc, char* argv[]) { ...@@ -343,6 +402,7 @@ int main(int argc, char* argv[]) {
testCustomWeights(); testCustomWeights();
testIllegalVariable(); testIllegalVariable();
testPeriodic(); testPeriodic();
testEnergyParameterDerivatives();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -171,7 +171,7 @@ void testContinuous2DFunction() { ...@@ -171,7 +171,7 @@ void testContinuous2DFunction() {
const double xmin = 0.4; const double xmin = 0.4;
const double xmax = 1.1; const double xmax = 1.1;
const double ymin = 0.0; const double ymin = 0.0;
const double ymax = 0.9; const double ymax = 0.95;
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
VerletIntegrator integrator(0.01); VerletIntegrator integrator(0.01);
...@@ -218,7 +218,7 @@ void testContinuous3DFunction() { ...@@ -218,7 +218,7 @@ void testContinuous3DFunction() {
const double ymin = 2.0; const double ymin = 2.0;
const double ymax = 2.9; const double ymax = 2.9;
const double zmin = 0.2; const double zmin = 0.2;
const double zmax = 1.3; const double zmax = 1.35;
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
VerletIntegrator integrator(0.01); VerletIntegrator integrator(0.01);
...@@ -406,6 +406,48 @@ void testPeriodic() { ...@@ -406,6 +406,48 @@ void testPeriodic() {
} }
} }
void testEnergyParameterDerivatives() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomCompoundBondForce* custom = new CustomCompoundBondForce(4, "k*(dihedral(p1,p2,p3,p4)-theta0)^2");
custom->addGlobalParameter("theta0", 0.0);
custom->addGlobalParameter("k", 0.0);
custom->addEnergyParameterDerivative("theta0");
custom->addEnergyParameterDerivative("k");
vector<int> particles(4);
particles[0] = 0;
particles[1] = 1;
particles[2] = 2;
particles[3] = 3;
vector<double> parameters;
custom->addBond(particles, parameters);
system.addForce(custom);
Context context(system, integrator, platform);
vector<Vec3> positions(4);
positions[0] = Vec3(0, 2, 0);
positions[1] = Vec3(0, 0, 0);
positions[2] = Vec3(1, 0, 0);
positions[3] = Vec3(1, 1, 1);
context.setPositions(positions);
double theta = M_PI/4;
for (int i = 0; i < 10; i++) {
double theta0 = 0.1*i;
double k = 10-i;
context.setParameter("theta0", theta0);
context.setParameter("k", k);
State state = context.getState(State::ParameterDerivatives);
map<string, double> derivs = state.getEnergyParameterDerivatives();
double dEdtheta0 = -2*k*(theta-theta0);
double dEdk = (theta-theta0)*(theta-theta0);
ASSERT_EQUAL_TOL(dEdtheta0, derivs["theta0"], 1e-5);
ASSERT_EQUAL_TOL(dEdk, derivs["k"], 1e-5);
}
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -418,6 +460,7 @@ int main(int argc, char* argv[]) { ...@@ -418,6 +460,7 @@ int main(int argc, char* argv[]) {
testMultipleBonds(); testMultipleBonds();
testIllegalVariable(); testIllegalVariable();
testPeriodic(); testPeriodic();
testEnergyParameterDerivatives();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2015 Stanford University and the Authors. * * Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -488,6 +488,54 @@ void testIllegalVariable() { ...@@ -488,6 +488,54 @@ void testIllegalVariable() {
ASSERT(threwException); ASSERT(threwException);
} }
void testEnergyParameterDerivatives() {
// Create a box of particles.
const int numParticles = 40;
const int numParameters = 4;
const double boxSize = 2.0;
const double delta = 1e-3;
const string paramNames[] = {"A", "B", "C", "D"};
const double paramValues[] = {0.8, 2.1, 3.2, 1.3};
System system;
system.setDefaultPeriodicBoxVectors(Vec3(boxSize, 0, 0), Vec3(0, boxSize, 0), Vec3(0, 0, boxSize));
CustomGBForce* force = new CustomGBForce();
system.addForce(force);
force->addComputedValue("a", "0.5*(r-A)^2", CustomGBForce::ParticlePair);
force->addComputedValue("b", "a+B", CustomGBForce::SingleParticle);
force->addEnergyTerm("C*(a1+b1+a2+b2+r)^0.8", CustomGBForce::ParticlePair);
force->addEnergyTerm("(D-B)*b", CustomGBForce::SingleParticle);
for (int i = 0; i < numParameters; i++)
force->addGlobalParameter(paramNames[i], paramValues[i]);
for (int i = numParameters-1; i >= 0; i--)
force->addEnergyParameterDerivative(paramNames[i]);
force->setNonbondedMethod(CustomGBForce::CutoffPeriodic);
force->setCutoffDistance(1.0);
vector<Vec3> positions;
vector<double> parameters;
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0);
force->addParticle(parameters);
positions.push_back(Vec3(genrand_real2(sfmt), genrand_real2(sfmt), genrand_real2(sfmt))*boxSize);
}
// Compute the energy derivative and compare it to a finite difference approximation.
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
context.setPositions(positions);
map<string, double> derivs = context.getState(State::ParameterDerivatives).getEnergyParameterDerivatives();
for (int i = 0; i < numParameters; i++) {
context.setParameter(paramNames[i], paramValues[i]+delta);
double energy1 = context.getState(State::Energy).getPotentialEnergy();
context.setParameter(paramNames[i], paramValues[i]-delta);
double energy2 = context.getState(State::Energy).getPotentialEnergy();
ASSERT_EQUAL_TOL((energy1-energy2)/(2*delta), derivs[paramNames[i]], 5e-3);
}
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -502,6 +550,7 @@ int main(int argc, char* argv[]) { ...@@ -502,6 +550,7 @@ int main(int argc, char* argv[]) {
testPositionDependence(); testPositionDependence();
testExclusions(); testExclusions();
testIllegalVariable(); testIllegalVariable();
testEnergyParameterDerivatives();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2015 Stanford University and the Authors. * * Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -29,15 +29,21 @@ ...@@ -29,15 +29,21 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. * * USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#ifdef WIN32
#define _USE_MATH_DEFINES // Needed to get M_PI
#endif
#include "openmm/internal/AssertionUtilities.h" #include "openmm/internal/AssertionUtilities.h"
#include "openmm/Context.h" #include "openmm/Context.h"
#include "openmm/AndersenThermostat.h" #include "openmm/AndersenThermostat.h"
#include "openmm/CustomAngleForce.h"
#include "openmm/CustomBondForce.h"
#include "openmm/CustomIntegrator.h"
#include "openmm/HarmonicBondForce.h" #include "openmm/HarmonicBondForce.h"
#include "openmm/NonbondedForce.h" #include "openmm/NonbondedForce.h"
#include "openmm/System.h" #include "openmm/System.h"
#include "openmm/CustomIntegrator.h"
#include "SimTKOpenMMRealType.h" #include "SimTKOpenMMRealType.h"
#include "sfmt/SFMT.h" #include "sfmt/SFMT.h"
#include <cmath>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
...@@ -770,6 +776,88 @@ void testChangingGlobal() { ...@@ -770,6 +776,88 @@ void testChangingGlobal() {
} }
} }
/**
* Test steps that depend on derivatives of the energy with respect to parameters.
*/
void testEnergyParameterDerivatives() {
System system;
for (int i = 0; i < 3; i++)
system.addParticle(1.0);
// Create some custom forces that depend on parameters.
CustomBondForce* bonds = new CustomBondForce("K*(A*r-r0)^2");
system.addForce(bonds);
bonds->addGlobalParameter("K", 2.0);
bonds->addGlobalParameter("A", 1.0);
bonds->addGlobalParameter("r0", 1.5);
bonds->addEnergyParameterDerivative("K");
bonds->addEnergyParameterDerivative("r0");
bonds->addBond(0, 1);
bonds->setForceGroup(0);
CustomAngleForce* angles = new CustomAngleForce("K*(B*theta-theta0)^2");
system.addForce(angles);
angles->addGlobalParameter("K", 2.0);
angles->addGlobalParameter("B", 1.0);
angles->addGlobalParameter("theta0", M_PI/3);
angles->addEnergyParameterDerivative("K");
angles->addEnergyParameterDerivative("theta0");
angles->addAngle(0, 1, 2);
angles->setForceGroup(1);
// Create an integrator that records parameter derivatives.
CustomIntegrator integrator(0.1);
integrator.addGlobalVariable("dEdK", 0.0);
integrator.addGlobalVariable("dEdr0", 0.0);
integrator.addPerDofVariable("dEdtheta0", 0.0);
integrator.addGlobalVariable("dEdK_0", 0.0);
integrator.addPerDofVariable("dEdr0_0", 0.0);
integrator.addGlobalVariable("dEdtheta0_0", 0.0);
integrator.addPerDofVariable("dEdK_1", 0.0);
integrator.addGlobalVariable("dEdr0_1", 0.0);
integrator.addGlobalVariable("dEdtheta0_1", 0.0);
integrator.addComputeGlobal("dEdK", "deriv(energy, K)");
integrator.addComputeGlobal("dEdr0", "deriv(energy, r0)");
integrator.addComputePerDof("dEdtheta0", "deriv(energy, theta0)");
integrator.addComputeGlobal("dEdK_0", "deriv(energy0, K)");
integrator.addComputePerDof("dEdr0_0", "deriv(energy0, r0)");
integrator.addComputeGlobal("dEdtheta0_0", "deriv(energy0, theta0)");
integrator.addComputePerDof("dEdK_1", "deriv(energy1, K)");
integrator.addComputeGlobal("dEdr0_1", "deriv(energy1, r0)");
integrator.addComputeGlobal("dEdtheta0_1", "deriv(energy1, theta0)");
// Create a Context.
Context context(system, integrator, platform);
vector<Vec3> positions(3);
positions[0] = Vec3(0, 1, 0);
positions[1] = Vec3(0, 0, 0);
positions[2] = Vec3(1, 0, 0);
context.setPositions(positions);
// Check the results.
integrator.step(1);
vector<Vec3> values;
double dEdK_0 = (1.0-1.5)*(1.0-1.5);
double dEdK_1 = (M_PI/2-M_PI/3)*(M_PI/2-M_PI/3);
ASSERT_EQUAL_TOL(dEdK_0, integrator.getGlobalVariableByName("dEdK_0"), 1e-5);
integrator.getPerDofVariableByName("dEdK_1", values);
ASSERT_EQUAL_TOL(dEdK_1, values[0][2], 1e-5);
ASSERT_EQUAL_TOL(dEdK_0+dEdK_1, integrator.getGlobalVariableByName("dEdK"), 1e-5);
double dEdr0 = -2.0*2.0*(1.0-1.5);
integrator.getPerDofVariableByName("dEdr0_0", values);
ASSERT_EQUAL_TOL(dEdr0, values[1][0], 1e-5);
ASSERT_EQUAL_TOL(0.0, integrator.getGlobalVariableByName("dEdr0_1"), 1e-5);
ASSERT_EQUAL_TOL(dEdr0, integrator.getGlobalVariableByName("dEdr0"), 1e-5);
double dEdtheta0 = -2.0*2.0*(M_PI/2-M_PI/3);
ASSERT_EQUAL_TOL(0.0, integrator.getGlobalVariableByName("dEdtheta0_0"), 1e-5);
ASSERT_EQUAL_TOL(dEdtheta0, integrator.getGlobalVariableByName("dEdtheta0_1"), 1e-5);
integrator.getPerDofVariableByName("dEdtheta0", values);
ASSERT_EQUAL_TOL(dEdtheta0, values[2][1], 1e-5);
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -790,6 +878,7 @@ int main(int argc, char* argv[]) { ...@@ -790,6 +878,7 @@ int main(int argc, char* argv[]) {
testIfBlock(); testIfBlock();
testWhileBlock(); testWhileBlock();
testChangingGlobal(); testChangingGlobal();
testEnergyParameterDerivatives();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2015 Stanford University and the Authors. * * Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -1041,6 +1041,84 @@ void testIllegalVariable() { ...@@ -1041,6 +1041,84 @@ void testIllegalVariable() {
ASSERT(threwException); ASSERT(threwException);
} }
void testEnergyParameterDerivatives() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* nonbonded = new CustomNonbondedForce("k*(r-r0)^2");
nonbonded->addGlobalParameter("r0", 0.0);
nonbonded->addGlobalParameter("k", 0.0);
nonbonded->addEnergyParameterDerivative("k");
nonbonded->addEnergyParameterDerivative("r0");
vector<double> parameters;
nonbonded->addParticle(parameters);
nonbonded->addParticle(parameters);
nonbonded->addParticle(parameters);
nonbonded->addExclusion(0, 2);
system.addForce(nonbonded);
Context context(system, integrator, platform);
vector<Vec3> positions(3);
positions[0] = Vec3(0, 2, 0);
positions[1] = Vec3(0, 0, 0);
positions[2] = Vec3(1, 0, 0);
context.setPositions(positions);
for (int i = 0; i < 10; i++) {
double r0 = 0.1*i;
double k = 10-i;
context.setParameter("r0", r0);
context.setParameter("k", k);
State state = context.getState(State::ParameterDerivatives);
map<string, double> derivs = state.getEnergyParameterDerivatives();
double dEdr0 = -2*k*((2-r0)+(1-r0));
double dEdk = (2-r0)*(2-r0) + (1-r0)*(1-r0);
ASSERT_EQUAL_TOL(dEdr0, derivs["r0"], 1e-5);
ASSERT_EQUAL_TOL(dEdk, derivs["k"], 1e-5);
}
}
void testEnergyParameterDerivatives2() {
// Create a box of particles.
const int numParticles = 30;
const double boxSize = 2.0;
const double a = 1.0;
const double delta = 1e-3;
System system;
system.setDefaultPeriodicBoxVectors(Vec3(boxSize, 0, 0), Vec3(0, boxSize, 0), Vec3(0, 0, boxSize));
CustomNonbondedForce* nonbonded = new CustomNonbondedForce("(r+a)^-4");
system.addForce(nonbonded);
nonbonded->addGlobalParameter("a", a);
nonbonded->addEnergyParameterDerivative("a");
nonbonded->setNonbondedMethod(CustomNonbondedForce::CutoffPeriodic);
nonbonded->setCutoffDistance(1.0);
nonbonded->setSwitchingDistance(0.9);
nonbonded->setUseSwitchingFunction(true);
nonbonded->setUseLongRangeCorrection(true);
vector<Vec3> positions;
vector<double> parameters;
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0);
nonbonded->addParticle(parameters);
positions.push_back(Vec3(genrand_real2(sfmt), genrand_real2(sfmt), genrand_real2(sfmt))*boxSize);
}
// Compute the energy derivative and compare it to a finite difference approximation.
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
context.setPositions(positions);
map<string, double> derivs = context.getState(State::ParameterDerivatives).getEnergyParameterDerivatives();
context.setParameter("a", a+delta);
double energy1 = context.getState(State::Energy).getPotentialEnergy();
context.setParameter("a", a-delta);
double energy2 = context.getState(State::Energy).getPotentialEnergy();
ASSERT_EQUAL_TOL((energy1-energy2)/(2*delta), derivs["a"], 1e-4);
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -1067,6 +1145,8 @@ int main(int argc, char* argv[]) { ...@@ -1067,6 +1145,8 @@ int main(int argc, char* argv[]) {
testInteractionGroupTabulatedFunction(); testInteractionGroupTabulatedFunction();
testMultipleCutoffs(); testMultipleCutoffs();
testIllegalVariable(); testIllegalVariable();
testEnergyParameterDerivatives();
testEnergyParameterDerivatives2();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -222,6 +222,43 @@ void testPeriodic() { ...@@ -222,6 +222,43 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.1*(1+std::cos(2*M_PI/3)), state.getPotentialEnergy(), TOL); ASSERT_EQUAL_TOL(1.1*(1+std::cos(2*M_PI/3)), state.getPotentialEnergy(), TOL);
} }
void testEnergyParameterDerivatives() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomTorsionForce* torsions = new CustomTorsionForce("k*(theta-theta0)^2");
torsions->addGlobalParameter("theta0", 0.0);
torsions->addGlobalParameter("k", 0.0);
torsions->addEnergyParameterDerivative("theta0");
torsions->addEnergyParameterDerivative("k");
vector<double> parameters;
torsions->addTorsion(0, 1, 2, 3, parameters);
system.addForce(torsions);
Context context(system, integrator, platform);
vector<Vec3> positions(4);
positions[0] = Vec3(0, 2, 0);
positions[1] = Vec3(0, 0, 0);
positions[2] = Vec3(1, 0, 0);
positions[3] = Vec3(1, 1, 1);
context.setPositions(positions);
double theta = M_PI/4;
for (int i = 0; i < 10; i++) {
double theta0 = 0.1*i;
double k = 10-i;
context.setParameter("theta0", theta0);
context.setParameter("k", k);
State state = context.getState(State::ParameterDerivatives);
map<string, double> derivs = state.getEnergyParameterDerivatives();
double dEdtheta0 = -2*k*(theta-theta0);
double dEdk = (theta-theta0)*(theta-theta0);
ASSERT_EQUAL_TOL(dEdtheta0, derivs["theta0"], 1e-5);
ASSERT_EQUAL_TOL(dEdk, derivs["k"], 1e-5);
}
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -231,6 +268,7 @@ int main(int argc, char* argv[]) { ...@@ -231,6 +268,7 @@ int main(int argc, char* argv[]) {
testRange(); testRange();
testIllegalVariable(); testIllegalVariable();
testPeriodic(); testPeriodic();
testEnergyParameterDerivatives();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -108,10 +108,9 @@ else(SWIG_EXECUTABLE) ...@@ -108,10 +108,9 @@ else(SWIG_EXECUTABLE)
set(SWIG_VERSION "0.0.0" CACHE STRING "Swig version" FORCE) set(SWIG_VERSION "0.0.0" CACHE STRING "Swig version" FORCE)
endif(SWIG_EXECUTABLE) endif(SWIG_EXECUTABLE)
# Enforce swig version # Enforce swig version
string(COMPARE LESS "${SWIG_VERSION}" "3.0.5" SWIG_VERSION_ERROR) if(SWIG_VERSION VERSION_LESS "3.0.5")
if(SWIG_VERSION_ERROR) message(SEND_ERROR "Swig version must be 3.0.5 or greater! (You have ${SWIG_VERSION})")
message("Swig version must be 3.0.5 or greater! (You have ${SWIG_VERSION})") endif(SWIG_VERSION VERSION_LESS "3.0.5")
endif(SWIG_VERSION_ERROR)
find_package(Doxygen REQUIRED) find_package(Doxygen REQUIRED)
mark_as_advanced(CLEAR DOXYGEN_EXECUTABLE) mark_as_advanced(CLEAR DOXYGEN_EXECUTABLE)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment