"plugins/vscode:/vscode.git/clone" did not exist on "ad1b1ad1ad0d30b7148770d2e1a2c0d187a6ef2c"
Commit 4e1e1b11 authored by Peter Eastman's avatar Peter Eastman
Browse files

Custom functions are now represented by natural splines

parent 06a98e93
...@@ -195,7 +195,7 @@ void testCustomFunctions() { ...@@ -195,7 +195,7 @@ void testCustomFunctions() {
vector<double> function(2); vector<double> function(2);
function[0] = 0; function[0] = 0;
function[1] = 1; function[1] = 1;
custom->addFunction("foo", function, 0, 10, true); custom->addFunction("foo", function, 0, 10);
system.addForce(custom); system.addForce(custom);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(3); vector<Vec3> positions(3);
......
...@@ -203,7 +203,7 @@ void testPeriodic() { ...@@ -203,7 +203,7 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL); ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL);
} }
void testTabulatedFunction(bool interpolating) { void testTabulatedFunction() {
ReferencePlatform platform; ReferencePlatform platform;
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
...@@ -215,7 +215,7 @@ void testTabulatedFunction(bool interpolating) { ...@@ -215,7 +215,7 @@ void testTabulatedFunction(bool interpolating) {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", table, 1.0, 6.0, interpolating); forceField->addFunction("fn", table, 1.0, 6.0);
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -233,6 +233,14 @@ void testTabulatedFunction(bool interpolating) { ...@@ -233,6 +233,14 @@ void testTabulatedFunction(bool interpolating) {
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1); ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
} }
for (int i = 1; i < 20; i++) {
double x = 0.25*i+1.0;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
} }
void testCoulombLennardJones() { void testCoulombLennardJones() {
...@@ -317,8 +325,7 @@ int main() { ...@@ -317,8 +325,7 @@ int main() {
testExclusions(); testExclusions();
testCutoff(); testCutoff();
testPeriodic(); testPeriodic();
testTabulatedFunction(true); testTabulatedFunction();
testTabulatedFunction(false);
testCoulombLennardJones(); testCoulombLennardJones();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -91,9 +91,8 @@ void CustomGBForceProxy::serialize(const void* object, SerializationNode& node) ...@@ -91,9 +91,8 @@ void CustomGBForceProxy::serialize(const void* object, SerializationNode& node)
string name; string name;
vector<double> values; vector<double> values;
double min, max; double min, max;
bool interpolating; force.getFunctionParameters(i, name, values, min, max);
force.getFunctionParameters(i, name, values, min, max, interpolating); SerializationNode& node = functions.createChildNode("Function").setStringProperty("name", name).setDoubleProperty("min", min).setDoubleProperty("max", max);
SerializationNode& node = functions.createChildNode("Function").setStringProperty("name", name).setDoubleProperty("min", min).setDoubleProperty("max", max).setIntProperty("interpolating", interpolating);
SerializationNode& valuesNode = node.createChildNode("Values"); SerializationNode& valuesNode = node.createChildNode("Values");
for (int j = 0; j < (int) values.size(); j++) for (int j = 0; j < (int) values.size(); j++)
valuesNode.createChildNode("Value").setDoubleProperty("v", values[j]); valuesNode.createChildNode("Value").setDoubleProperty("v", values[j]);
...@@ -152,7 +151,7 @@ void* CustomGBForceProxy::deserialize(const SerializationNode& node) const { ...@@ -152,7 +151,7 @@ void* CustomGBForceProxy::deserialize(const SerializationNode& node) const {
vector<double> values; vector<double> values;
for (int j = 0; j < (int) valuesNode.getChildren().size(); j++) for (int j = 0; j < (int) valuesNode.getChildren().size(); j++)
values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v")); values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v"));
force->addFunction(function.getStringProperty("name"), values, function.getDoubleProperty("min"), function.getDoubleProperty("max"), (bool) function.getIntProperty("interpolating")); force->addFunction(function.getStringProperty("name"), values, function.getDoubleProperty("min"), function.getDoubleProperty("max"));
} }
return force; return force;
} }
......
...@@ -96,9 +96,8 @@ void CustomHbondForceProxy::serialize(const void* object, SerializationNode& nod ...@@ -96,9 +96,8 @@ void CustomHbondForceProxy::serialize(const void* object, SerializationNode& nod
string name; string name;
vector<double> values; vector<double> values;
double min, max; double min, max;
bool interpolating; force.getFunctionParameters(i, name, values, min, max);
force.getFunctionParameters(i, name, values, min, max, interpolating); SerializationNode& node = functions.createChildNode("Function").setStringProperty("name", name).setDoubleProperty("min", min).setDoubleProperty("max", max);
SerializationNode& node = functions.createChildNode("Function").setStringProperty("name", name).setDoubleProperty("min", min).setDoubleProperty("max", max).setIntProperty("interpolating", interpolating);
SerializationNode& valuesNode = node.createChildNode("Values"); SerializationNode& valuesNode = node.createChildNode("Values");
for (int j = 0; j < (int) values.size(); j++) for (int j = 0; j < (int) values.size(); j++)
valuesNode.createChildNode("Value").setDoubleProperty("v", values[j]); valuesNode.createChildNode("Value").setDoubleProperty("v", values[j]);
...@@ -164,7 +163,7 @@ void* CustomHbondForceProxy::deserialize(const SerializationNode& node) const { ...@@ -164,7 +163,7 @@ void* CustomHbondForceProxy::deserialize(const SerializationNode& node) const {
vector<double> values; vector<double> values;
for (int j = 0; j < (int) valuesNode.getChildren().size(); j++) for (int j = 0; j < (int) valuesNode.getChildren().size(); j++)
values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v")); values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v"));
force->addFunction(function.getStringProperty("name"), values, function.getDoubleProperty("min"), function.getDoubleProperty("max"), (bool) function.getIntProperty("interpolating")); force->addFunction(function.getStringProperty("name"), values, function.getDoubleProperty("min"), function.getDoubleProperty("max"));
} }
return force; return force;
} }
......
...@@ -78,9 +78,8 @@ void CustomNonbondedForceProxy::serialize(const void* object, SerializationNode& ...@@ -78,9 +78,8 @@ void CustomNonbondedForceProxy::serialize(const void* object, SerializationNode&
string name; string name;
vector<double> values; vector<double> values;
double min, max; double min, max;
bool interpolating; force.getFunctionParameters(i, name, values, min, max);
force.getFunctionParameters(i, name, values, min, max, interpolating); SerializationNode& node = functions.createChildNode("Function").setStringProperty("name", name).setDoubleProperty("min", min).setDoubleProperty("max", max);
SerializationNode& node = functions.createChildNode("Function").setStringProperty("name", name).setDoubleProperty("min", min).setDoubleProperty("max", max).setIntProperty("interpolating", interpolating);
SerializationNode& valuesNode = node.createChildNode("Values"); SerializationNode& valuesNode = node.createChildNode("Values");
for (int j = 0; j < (int) values.size(); j++) for (int j = 0; j < (int) values.size(); j++)
valuesNode.createChildNode("Value").setDoubleProperty("v", values[j]); valuesNode.createChildNode("Value").setDoubleProperty("v", values[j]);
...@@ -129,7 +128,7 @@ void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) cons ...@@ -129,7 +128,7 @@ void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) cons
vector<double> values; vector<double> values;
for (int j = 0; j < (int) valuesNode.getChildren().size(); j++) for (int j = 0; j < (int) valuesNode.getChildren().size(); j++)
values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v")); values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v"));
force->addFunction(function.getStringProperty("name"), values, function.getDoubleProperty("min"), function.getDoubleProperty("max"), (bool) function.getIntProperty("interpolating")); force->addFunction(function.getStringProperty("name"), values, function.getDoubleProperty("min"), function.getDoubleProperty("max"));
} }
return force; return force;
} }
......
...@@ -63,7 +63,7 @@ void testSerialization() { ...@@ -63,7 +63,7 @@ void testSerialization() {
vector<double> values(10); vector<double> values(10);
for (int i = 0; i < 10; i++) for (int i = 0; i < 10; i++)
values[i] = sin((double) i); values[i] = sin((double) i);
force.addFunction("f", values, 0.5, 1.5, true); force.addFunction("f", values, 0.5, 1.5);
// Serialize and then deserialize it. // Serialize and then deserialize it.
...@@ -125,13 +125,11 @@ void testSerialization() { ...@@ -125,13 +125,11 @@ void testSerialization() {
string name1, name2; string name1, name2;
double min1, min2, max1, max2; double min1, min2, max1, max2;
vector<double> val1, val2; vector<double> val1, val2;
bool interp1, interp2; force.getFunctionParameters(i, name1, val1, min1, max1);
force.getFunctionParameters(i, name1, val1, min1, max1, interp1); force2.getFunctionParameters(i, name2, val2, min2, max2);
force2.getFunctionParameters(i, name2, val2, min2, max2, interp2);
ASSERT_EQUAL(name1, name2); ASSERT_EQUAL(name1, name2);
ASSERT_EQUAL(min1, min2); ASSERT_EQUAL(min1, min2);
ASSERT_EQUAL(max1, max2); ASSERT_EQUAL(max1, max2);
ASSERT_EQUAL(interp1, interp2);
ASSERT_EQUAL(val1.size(), val2.size()); ASSERT_EQUAL(val1.size(), val2.size());
for (int j = 0; j < val1.size(); j++) for (int j = 0; j < val1.size(); j++)
ASSERT_EQUAL(val1[j], val2[j]); ASSERT_EQUAL(val1[j], val2[j]);
......
...@@ -66,7 +66,7 @@ void testSerialization() { ...@@ -66,7 +66,7 @@ void testSerialization() {
vector<double> values(10); vector<double> values(10);
for (int i = 0; i < 10; i++) for (int i = 0; i < 10; i++)
values[i] = sin((double) i); values[i] = sin((double) i);
force.addFunction("f", values, 0.5, 1.5, true); force.addFunction("f", values, 0.5, 1.5);
// Serialize and then deserialize it. // Serialize and then deserialize it.
...@@ -130,13 +130,11 @@ void testSerialization() { ...@@ -130,13 +130,11 @@ void testSerialization() {
string name1, name2; string name1, name2;
double min1, min2, max1, max2; double min1, min2, max1, max2;
vector<double> val1, val2; vector<double> val1, val2;
bool interp1, interp2; force.getFunctionParameters(i, name1, val1, min1, max1);
force.getFunctionParameters(i, name1, val1, min1, max1, interp1); force2.getFunctionParameters(i, name2, val2, min2, max2);
force2.getFunctionParameters(i, name2, val2, min2, max2, interp2);
ASSERT_EQUAL(name1, name2); ASSERT_EQUAL(name1, name2);
ASSERT_EQUAL(min1, min2); ASSERT_EQUAL(min1, min2);
ASSERT_EQUAL(max1, max2); ASSERT_EQUAL(max1, max2);
ASSERT_EQUAL(interp1, interp2);
ASSERT_EQUAL(val1.size(), val2.size()); ASSERT_EQUAL(val1.size(), val2.size());
for (int j = 0; j < val1.size(); j++) for (int j = 0; j < val1.size(); j++)
ASSERT_EQUAL(val1[j], val2[j]); ASSERT_EQUAL(val1[j], val2[j]);
......
...@@ -59,7 +59,7 @@ void testSerialization() { ...@@ -59,7 +59,7 @@ void testSerialization() {
vector<double> values(10); vector<double> values(10);
for (int i = 0; i < 10; i++) for (int i = 0; i < 10; i++)
values[i] = sin((double) i); values[i] = sin((double) i);
force.addFunction("f", values, 0.5, 1.5, true); force.addFunction("f", values, 0.5, 1.5);
// Serialize and then deserialize it. // Serialize and then deserialize it.
...@@ -103,13 +103,11 @@ void testSerialization() { ...@@ -103,13 +103,11 @@ void testSerialization() {
string name1, name2; string name1, name2;
double min1, min2, max1, max2; double min1, min2, max1, max2;
vector<double> val1, val2; vector<double> val1, val2;
bool interp1, interp2; force.getFunctionParameters(i, name1, val1, min1, max1);
force.getFunctionParameters(i, name1, val1, min1, max1, interp1); force2.getFunctionParameters(i, name2, val2, min2, max2);
force2.getFunctionParameters(i, name2, val2, min2, max2, interp2);
ASSERT_EQUAL(name1, name2); ASSERT_EQUAL(name1, name2);
ASSERT_EQUAL(min1, min2); ASSERT_EQUAL(min1, min2);
ASSERT_EQUAL(max1, max2); ASSERT_EQUAL(max1, max2);
ASSERT_EQUAL(interp1, interp2);
ASSERT_EQUAL(val1.size(), val2.size()); ASSERT_EQUAL(val1.size(), val2.size());
for (int j = 0; j < val1.size(); j++) for (int j = 0; j < val1.size(); j++)
ASSERT_EQUAL(val1[j], val2[j]); ASSERT_EQUAL(val1[j], val2[j]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment