Commit 0fe7612b authored by peastman's avatar peastman
Browse files

Merge pull request #337 from peastman/functions

Created new API for tabulated functions
parents 7a7055b3 ed31a458
......@@ -92,16 +92,8 @@ void CustomHbondForceProxy::serialize(const void* object, SerializationNode& nod
exclusions.createChildNode("Exclusion").setIntProperty("donor", donor).setIntProperty("acceptor", acceptor);
}
SerializationNode& functions = node.createChildNode("Functions");
for (int i = 0; i < force.getNumFunctions(); i++) {
string name;
vector<double> values;
double min, max;
force.getFunctionParameters(i, name, values, min, max);
SerializationNode& node = functions.createChildNode("Function").setStringProperty("name", name).setDoubleProperty("min", min).setDoubleProperty("max", max);
SerializationNode& valuesNode = node.createChildNode("Values");
for (int j = 0; j < (int) values.size(); j++)
valuesNode.createChildNode("Value").setDoubleProperty("v", values[j]);
}
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions.createChildNode("Function", &force.getTabulatedFunction(i)).setStringProperty("name", force.getTabulatedFunctionName(i));
}
void* CustomHbondForceProxy::deserialize(const SerializationNode& node) const {
......@@ -159,11 +151,18 @@ void* CustomHbondForceProxy::deserialize(const SerializationNode& node) const {
const SerializationNode& functions = node.getChildNode("Functions");
for (int i = 0; i < (int) functions.getChildren().size(); i++) {
const SerializationNode& function = functions.getChildren()[i];
if (function.hasProperty("type")) {
force->addTabulatedFunction(function.getStringProperty("name"), function.decodeObject<TabulatedFunction>());
}
else {
// This is an old file created before TabulatedFunction existed.
const SerializationNode& valuesNode = function.getChildNode("Values");
vector<double> values;
for (int j = 0; j < (int) valuesNode.getChildren().size(); j++)
values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v"));
force->addFunction(function.getStringProperty("name"), values, function.getDoubleProperty("min"), function.getDoubleProperty("max"));
force->addTabulatedFunction(function.getStringProperty("name"), new Continuous1DFunction(values, function.getDoubleProperty("min"), function.getDoubleProperty("max")));
}
}
return force;
}
......
......@@ -74,16 +74,8 @@ void CustomNonbondedForceProxy::serialize(const void* object, SerializationNode&
exclusions.createChildNode("Exclusion").setIntProperty("p1", particle1).setIntProperty("p2", particle2);
}
SerializationNode& functions = node.createChildNode("Functions");
for (int i = 0; i < force.getNumFunctions(); i++) {
string name;
vector<double> values;
double min, max;
force.getFunctionParameters(i, name, values, min, max);
SerializationNode& node = functions.createChildNode("Function").setStringProperty("name", name).setDoubleProperty("min", min).setDoubleProperty("max", max);
SerializationNode& valuesNode = node.createChildNode("Values");
for (int j = 0; j < (int) values.size(); j++)
valuesNode.createChildNode("Value").setDoubleProperty("v", values[j]);
}
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions.createChildNode("Function", &force.getTabulatedFunction(i)).setStringProperty("name", force.getTabulatedFunctionName(i));
}
void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) const {
......@@ -124,11 +116,18 @@ void* CustomNonbondedForceProxy::deserialize(const SerializationNode& node) cons
const SerializationNode& functions = node.getChildNode("Functions");
for (int i = 0; i < (int) functions.getChildren().size(); i++) {
const SerializationNode& function = functions.getChildren()[i];
if (function.hasProperty("type")) {
force->addTabulatedFunction(function.getStringProperty("name"), function.decodeObject<TabulatedFunction>());
}
else {
// This is an old file created before TabulatedFunction existed.
const SerializationNode& valuesNode = function.getChildNode("Values");
vector<double> values;
for (int j = 0; j < (int) valuesNode.getChildren().size(); j++)
values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v"));
force->addFunction(function.getStringProperty("name"), values, function.getDoubleProperty("min"), function.getDoubleProperty("max"));
force->addTabulatedFunction(function.getStringProperty("name"), new Continuous1DFunction(values, function.getDoubleProperty("min"), function.getDoubleProperty("max")));
}
}
return force;
}
......
......@@ -30,6 +30,7 @@
* -------------------------------------------------------------------------- */
#include "openmm/AndersenThermostat.h"
#include "openmm/BrownianIntegrator.h"
#include "openmm/CMAPTorsionForce.h"
#include "openmm/CMMotionRemover.h"
#include "openmm/CustomAngleForce.h"
......@@ -38,26 +39,26 @@
#include "openmm/CustomExternalForce.h"
#include "openmm/CustomGBForce.h"
#include "openmm/CustomHbondForce.h"
#include "openmm/CustomIntegrator.h"
#include "openmm/CustomNonbondedForce.h"
#include "openmm/CustomTorsionForce.h"
#include "openmm/HarmonicAngleForce.h"
#include "openmm/GBSAOBCForce.h"
#include "openmm/GBVIForce.h"
#include "openmm/HarmonicAngleForce.h"
#include "openmm/HarmonicBondForce.h"
#include "openmm/LangevinIntegrator.h"
#include "openmm/MonteCarloBarostat.h"
#include "openmm/NonbondedForce.h"
#include "openmm/PeriodicTorsionForce.h"
#include "openmm/RBTorsionForce.h"
#include "openmm/System.h"
#include "openmm/BrownianIntegrator.h"
#include "openmm/CustomIntegrator.h"
#include "openmm/LangevinIntegrator.h"
#include "openmm/TabulatedFunction.h"
#include "openmm/VariableLangevinIntegrator.h"
#include "openmm/VariableVerletIntegrator.h"
#include "openmm/VerletIntegrator.h"
#include "openmm/serialization/SerializationProxy.h"
#include "openmm/serialization/BrownianIntegratorProxy.h"
#include "openmm/serialization/AndersenThermostatProxy.h"
#include "openmm/serialization/CMAPTorsionForceProxy.h"
#include "openmm/serialization/CMMotionRemoverProxy.h"
......@@ -67,23 +68,21 @@
#include "openmm/serialization/CustomExternalForceProxy.h"
#include "openmm/serialization/CustomGBForceProxy.h"
#include "openmm/serialization/CustomHbondForceProxy.h"
#include "openmm/serialization/CustomIntegratorProxy.h"
#include "openmm/serialization/CustomNonbondedForceProxy.h"
#include "openmm/serialization/CustomTorsionForceProxy.h"
#include "openmm/serialization/GBSAOBCForceProxy.h"
#include "openmm/serialization/GBVIForceProxy.h"
#include "openmm/serialization/HarmonicAngleForceProxy.h"
#include "openmm/serialization/HarmonicBondForceProxy.h"
#include "openmm/serialization/LangevinIntegratorProxy.h"
#include "openmm/serialization/MonteCarloBarostatProxy.h"
#include "openmm/serialization/NonbondedForceProxy.h"
#include "openmm/serialization/PeriodicTorsionForceProxy.h"
#include "openmm/serialization/RBTorsionForceProxy.h"
#include "openmm/serialization/SystemProxy.h"
#include "openmm/serialization/StateProxy.h"
#include "openmm/serialization/BrownianIntegratorProxy.h"
#include "openmm/serialization/CustomIntegratorProxy.h"
#include "openmm/serialization/LangevinIntegratorProxy.h"
#include "openmm/serialization/SystemProxy.h"
#include "openmm/serialization/TabulatedFunctionProxies.h"
#include "openmm/serialization/VariableLangevinIntegratorProxy.h"
#include "openmm/serialization/VariableVerletIntegratorProxy.h"
#include "openmm/serialization/VerletIntegratorProxy.h"
......@@ -104,29 +103,35 @@ using namespace OpenMM;
extern "C" void registerSerializationProxies() {
SerializationProxy::registerProxy(typeid(AndersenThermostat), new AndersenThermostatProxy());
SerializationProxy::registerProxy(typeid(BrownianIntegrator), new BrownianIntegratorProxy());
SerializationProxy::registerProxy(typeid(CMAPTorsionForce), new CMAPTorsionForceProxy());
SerializationProxy::registerProxy(typeid(CMMotionRemover), new CMMotionRemoverProxy());
SerializationProxy::registerProxy(typeid(Continuous1DFunction), new Continuous1DFunctionProxy());
SerializationProxy::registerProxy(typeid(Continuous2DFunction), new Continuous2DFunctionProxy());
SerializationProxy::registerProxy(typeid(Continuous3DFunction), new Continuous3DFunctionProxy());
SerializationProxy::registerProxy(typeid(CustomAngleForce), new CustomAngleForceProxy());
SerializationProxy::registerProxy(typeid(CustomBondForce), new CustomBondForceProxy());
SerializationProxy::registerProxy(typeid(CustomCompoundBondForce), new CustomCompoundBondForceProxy());
SerializationProxy::registerProxy(typeid(CustomExternalForce), new CustomExternalForceProxy());
SerializationProxy::registerProxy(typeid(CustomGBForce), new CustomGBForceProxy());
SerializationProxy::registerProxy(typeid(CustomHbondForce), new CustomHbondForceProxy());
SerializationProxy::registerProxy(typeid(CustomIntegrator), new CustomIntegratorProxy());
SerializationProxy::registerProxy(typeid(CustomNonbondedForce), new CustomNonbondedForceProxy());
SerializationProxy::registerProxy(typeid(CustomTorsionForce), new CustomTorsionForceProxy());
SerializationProxy::registerProxy(typeid(Discrete1DFunction), new Discrete1DFunctionProxy());
SerializationProxy::registerProxy(typeid(Discrete2DFunction), new Discrete2DFunctionProxy());
SerializationProxy::registerProxy(typeid(Discrete3DFunction), new Discrete3DFunctionProxy());
SerializationProxy::registerProxy(typeid(GBSAOBCForce), new GBSAOBCForceProxy());
SerializationProxy::registerProxy(typeid(GBVIForce), new GBVIForceProxy());
SerializationProxy::registerProxy(typeid(HarmonicAngleForce), new HarmonicAngleForceProxy());
SerializationProxy::registerProxy(typeid(HarmonicBondForce), new HarmonicBondForceProxy());
SerializationProxy::registerProxy(typeid(LangevinIntegrator), new LangevinIntegratorProxy());
SerializationProxy::registerProxy(typeid(MonteCarloBarostat), new MonteCarloBarostatProxy());
SerializationProxy::registerProxy(typeid(NonbondedForce), new NonbondedForceProxy());
SerializationProxy::registerProxy(typeid(PeriodicTorsionForce), new PeriodicTorsionForceProxy());
SerializationProxy::registerProxy(typeid(RBTorsionForce), new RBTorsionForceProxy());
SerializationProxy::registerProxy(typeid(System), new SystemProxy());
SerializationProxy::registerProxy(typeid(State), new StateProxy());
SerializationProxy::registerProxy(typeid(BrownianIntegrator), new BrownianIntegratorProxy());
SerializationProxy::registerProxy(typeid(CustomIntegrator), new CustomIntegratorProxy());
SerializationProxy::registerProxy(typeid(LangevinIntegrator), new LangevinIntegratorProxy());
SerializationProxy::registerProxy(typeid(VariableLangevinIntegrator), new VariableLangevinIntegratorProxy());
SerializationProxy::registerProxy(typeid(VariableVerletIntegrator), new VariableVerletIntegratorProxy());
SerializationProxy::registerProxy(typeid(VerletIntegrator), new VerletIntegratorProxy());
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "openmm/serialization/TabulatedFunctionProxies.h"
#include "openmm/serialization/SerializationNode.h"
#include "openmm/TabulatedFunction.h"
#include <sstream>
using namespace OpenMM;
using namespace std;
Continuous1DFunctionProxy::Continuous1DFunctionProxy() : SerializationProxy("Continuous1DFunction") {
}
void Continuous1DFunctionProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1);
const Continuous1DFunction& function = *reinterpret_cast<const Continuous1DFunction*>(object);
double min, max;
vector<double> values;
function.getFunctionParameters(values, min, max);
node.setDoubleProperty("min", min);
node.setDoubleProperty("max", max);
SerializationNode& valuesNode = node.createChildNode("Values");
for (int j = 0; j < (int) values.size(); j++)
valuesNode.createChildNode("Value").setDoubleProperty("v", values[j]);
}
void* Continuous1DFunctionProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
const SerializationNode& valuesNode = node.getChildNode("Values");
vector<double> values;
for (int j = 0; j < (int) valuesNode.getChildren().size(); j++)
values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v"));
return new Continuous1DFunction(values, node.getDoubleProperty("min"), node.getDoubleProperty("max"));
}
Continuous2DFunctionProxy::Continuous2DFunctionProxy() : SerializationProxy("Continuous2DFunction") {
}
void Continuous2DFunctionProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1);
const Continuous2DFunction& function = *reinterpret_cast<const Continuous2DFunction*>(object);
int xsize, ysize;
double xmin, xmax, ymin, ymax;
vector<double> values;
function.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
node.setDoubleProperty("xsize", xsize);
node.setDoubleProperty("ysize", ysize);
node.setDoubleProperty("xmin", xmin);
node.setDoubleProperty("xmax", xmax);
node.setDoubleProperty("ymin", ymin);
node.setDoubleProperty("ymax", ymax);
SerializationNode& valuesNode = node.createChildNode("Values");
for (int j = 0; j < (int) values.size(); j++)
valuesNode.createChildNode("Value").setDoubleProperty("v", values[j]);
}
void* Continuous2DFunctionProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
const SerializationNode& valuesNode = node.getChildNode("Values");
vector<double> values;
for (int j = 0; j < (int) valuesNode.getChildren().size(); j++)
values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v"));
return new Continuous2DFunction(node.getIntProperty("xsize"), node.getIntProperty("ysize"), values,
node.getDoubleProperty("xmin"), node.getDoubleProperty("xmax"), node.getDoubleProperty("ymin"), node.getDoubleProperty("ymax"));
}
Continuous3DFunctionProxy::Continuous3DFunctionProxy() : SerializationProxy("Continuous3DFunction") {
}
void Continuous3DFunctionProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1);
const Continuous3DFunction& function = *reinterpret_cast<const Continuous3DFunction*>(object);
int xsize, ysize, zsize;
double xmin, xmax, ymin, ymax, zmin, zmax;
vector<double> values;
function.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
node.setDoubleProperty("xsize", xsize);
node.setDoubleProperty("ysize", ysize);
node.setDoubleProperty("zsize", zsize);
node.setDoubleProperty("xmin", xmin);
node.setDoubleProperty("xmax", xmax);
node.setDoubleProperty("ymin", ymin);
node.setDoubleProperty("ymax", ymax);
node.setDoubleProperty("zmin", zmin);
node.setDoubleProperty("zmax", zmax);
SerializationNode& valuesNode = node.createChildNode("Values");
for (int j = 0; j < (int) values.size(); j++)
valuesNode.createChildNode("Value").setDoubleProperty("v", values[j]);
}
void* Continuous3DFunctionProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
const SerializationNode& valuesNode = node.getChildNode("Values");
vector<double> values;
for (int j = 0; j < (int) valuesNode.getChildren().size(); j++)
values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v"));
return new Continuous3DFunction(node.getIntProperty("xsize"), node.getIntProperty("ysize"), node.getIntProperty("zsize"), values,
node.getDoubleProperty("xmin"), node.getDoubleProperty("xmax"), node.getDoubleProperty("ymin"), node.getDoubleProperty("ymax"),
node.getDoubleProperty("zmin"), node.getDoubleProperty("zmax"));
}
Discrete1DFunctionProxy::Discrete1DFunctionProxy() : SerializationProxy("Discrete1DFunction") {
}
void Discrete1DFunctionProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1);
const Discrete1DFunction& function = *reinterpret_cast<const Discrete1DFunction*>(object);
vector<double> values;
function.getFunctionParameters(values);
SerializationNode& valuesNode = node.createChildNode("Values");
for (int j = 0; j < (int) values.size(); j++)
valuesNode.createChildNode("Value").setDoubleProperty("v", values[j]);
}
void* Discrete1DFunctionProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
const SerializationNode& valuesNode = node.getChildNode("Values");
vector<double> values;
for (int j = 0; j < (int) valuesNode.getChildren().size(); j++)
values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v"));
return new Discrete1DFunction(values);
}
Discrete2DFunctionProxy::Discrete2DFunctionProxy() : SerializationProxy("Discrete2DFunction") {
}
void Discrete2DFunctionProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1);
const Discrete2DFunction& function = *reinterpret_cast<const Discrete2DFunction*>(object);
int xsize, ysize;
vector<double> values;
function.getFunctionParameters(xsize, ysize, values);
node.setDoubleProperty("xsize", xsize);
node.setDoubleProperty("ysize", ysize);
SerializationNode& valuesNode = node.createChildNode("Values");
for (int j = 0; j < (int) values.size(); j++)
valuesNode.createChildNode("Value").setDoubleProperty("v", values[j]);
}
void* Discrete2DFunctionProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
const SerializationNode& valuesNode = node.getChildNode("Values");
vector<double> values;
for (int j = 0; j < (int) valuesNode.getChildren().size(); j++)
values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v"));
return new Discrete2DFunction(node.getIntProperty("xsize"), node.getIntProperty("ysize"), values);
}
Discrete3DFunctionProxy::Discrete3DFunctionProxy() : SerializationProxy("Discrete3DFunction") {
}
void Discrete3DFunctionProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1);
const Discrete3DFunction& function = *reinterpret_cast<const Discrete3DFunction*>(object);
int xsize, ysize, zsize;
vector<double> values;
function.getFunctionParameters(xsize, ysize, zsize, values);
node.setDoubleProperty("xsize", xsize);
node.setDoubleProperty("ysize", ysize);
node.setDoubleProperty("zsize", zsize);
SerializationNode& valuesNode = node.createChildNode("Values");
for (int j = 0; j < (int) values.size(); j++)
valuesNode.createChildNode("Value").setDoubleProperty("v", values[j]);
}
void* Discrete3DFunctionProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
const SerializationNode& valuesNode = node.getChildNode("Values");
vector<double> values;
for (int j = 0; j < (int) valuesNode.getChildren().size(); j++)
values.push_back(valuesNode.getChildren()[j].getDoubleProperty("v"));
return new Discrete3DFunction(node.getIntProperty("xsize"), node.getIntProperty("ysize"), node.getIntProperty("zsize"), values);
}
......@@ -62,6 +62,10 @@ void testSerialization() {
particles[2] = 1;
params[0] = 2.1;
force.addBond(particles, params);
vector<double> values(10);
for (int i = 0; i < 10; i++)
values[i] = sin((double) i);
force.addTabulatedFunction("f", new Continuous1DFunction(values, 0.5, 1.5));
// Serialize and then deserialize it.
......@@ -95,6 +99,19 @@ void testSerialization() {
for (int j = 0; j < (int) particles1.size(); j++)
ASSERT_EQUAL(particles1[j], particles2[j]);
}
ASSERT_EQUAL(force.getNumTabulatedFunctions(), force2.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
double min1, min2, max1, max2;
vector<double> val1, val2;
dynamic_cast<Continuous1DFunction&>(force.getTabulatedFunction(i)).getFunctionParameters(val1, min1, max1);
dynamic_cast<Continuous1DFunction&>(force2.getTabulatedFunction(i)).getFunctionParameters(val2, min2, max2);
ASSERT_EQUAL(force.getTabulatedFunctionName(i), force2.getTabulatedFunctionName(i));
ASSERT_EQUAL(min1, min2);
ASSERT_EQUAL(max1, max2);
ASSERT_EQUAL(val1.size(), val2.size());
for (int j = 0; j < (int) val1.size(); j++)
ASSERT_EQUAL(val1[j], val2[j]);
}
}
int main() {
......
......@@ -63,7 +63,7 @@ void testSerialization() {
vector<double> values(10);
for (int i = 0; i < 10; i++)
values[i] = sin((double) i);
force.addFunction("f", values, 0.5, 1.5);
force.addTabulatedFunction("f", new Discrete1DFunction(values));
// Serialize and then deserialize it.
......@@ -120,16 +120,12 @@ void testSerialization() {
ASSERT_EQUAL(a1, a2);
ASSERT_EQUAL(b1, b2);
}
ASSERT_EQUAL(force.getNumFunctions(), force2.getNumFunctions());
for (int i = 0; i < force.getNumFunctions(); i++) {
string name1, name2;
double min1, min2, max1, max2;
ASSERT_EQUAL(force.getNumTabulatedFunctions(), force2.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
vector<double> val1, val2;
force.getFunctionParameters(i, name1, val1, min1, max1);
force2.getFunctionParameters(i, name2, val2, min2, max2);
ASSERT_EQUAL(name1, name2);
ASSERT_EQUAL(min1, min2);
ASSERT_EQUAL(max1, max2);
dynamic_cast<Discrete1DFunction&>(force.getTabulatedFunction(i)).getFunctionParameters(val1);
dynamic_cast<Discrete1DFunction&>(force2.getTabulatedFunction(i)).getFunctionParameters(val2);
ASSERT_EQUAL(force.getTabulatedFunctionName(i), force2.getTabulatedFunctionName(i));
ASSERT_EQUAL(val1.size(), val2.size());
for (int j = 0; j < (int) val1.size(); j++)
ASSERT_EQUAL(val1[j], val2[j]);
......
......@@ -66,7 +66,7 @@ void testSerialization() {
vector<double> values(10);
for (int i = 0; i < 10; i++)
values[i] = sin((double) i);
force.addFunction("f", values, 0.5, 1.5);
force.addTabulatedFunction("f", new Discrete1DFunction(values));
// Serialize and then deserialize it.
......@@ -125,16 +125,12 @@ void testSerialization() {
ASSERT_EQUAL(a1, a2);
ASSERT_EQUAL(b1, b2);
}
ASSERT_EQUAL(force.getNumFunctions(), force2.getNumFunctions());
for (int i = 0; i < force.getNumFunctions(); i++) {
string name1, name2;
double min1, min2, max1, max2;
ASSERT_EQUAL(force.getNumTabulatedFunctions(), force2.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
vector<double> val1, val2;
force.getFunctionParameters(i, name1, val1, min1, max1);
force2.getFunctionParameters(i, name2, val2, min2, max2);
ASSERT_EQUAL(name1, name2);
ASSERT_EQUAL(min1, min2);
ASSERT_EQUAL(max1, max2);
dynamic_cast<Discrete1DFunction&>(force.getTabulatedFunction(i)).getFunctionParameters(val1);
dynamic_cast<Discrete1DFunction&>(force2.getTabulatedFunction(i)).getFunctionParameters(val2);
ASSERT_EQUAL(force.getTabulatedFunctionName(i), force2.getTabulatedFunctionName(i));
ASSERT_EQUAL(val1.size(), val2.size());
for (int j = 0; j < (int) val1.size(); j++)
ASSERT_EQUAL(val1[j], val2[j]);
......
......@@ -98,14 +98,13 @@ void testSerialization() {
ASSERT_EQUAL(a1, a2);
ASSERT_EQUAL(b1, b2);
}
ASSERT_EQUAL(force.getNumFunctions(), force2.getNumFunctions());
for (int i = 0; i < force.getNumFunctions(); i++) {
string name1, name2;
ASSERT_EQUAL(force.getNumTabulatedFunctions(), force2.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
double min1, min2, max1, max2;
vector<double> val1, val2;
force.getFunctionParameters(i, name1, val1, min1, max1);
force2.getFunctionParameters(i, name2, val2, min2, max2);
ASSERT_EQUAL(name1, name2);
dynamic_cast<Continuous1DFunction&>(force.getTabulatedFunction(i)).getFunctionParameters(val1, min1, max1);
dynamic_cast<Continuous1DFunction&>(force2.getTabulatedFunction(i)).getFunctionParameters(val2, min2, max2);
ASSERT_EQUAL(force.getTabulatedFunctionName(i), force2.getTabulatedFunctionName(i));
ASSERT_EQUAL(min1, min2);
ASSERT_EQUAL(max1, max2);
ASSERT_EQUAL(val1.size(), val2.size());
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2010-2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "openmm/TabulatedFunction.h"
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/serialization/XmlSerializer.h"
#include <iostream>
#include <sstream>
using namespace OpenMM;
using namespace std;
void testContinuous1DFunction() {
// Create a function.
double min = 0.5, max = 1.5;
vector<double> values(60);
for (int i = 0; i < (int) values.size(); i++)
values[i] = sin((double) i);
Continuous1DFunction function(values, min, max);
// Serialize and then deserialize it.
stringstream buffer;
XmlSerializer::serialize<Continuous1DFunction>(&function, "Function", buffer);
Continuous1DFunction* copy = XmlSerializer::deserialize<Continuous1DFunction>(buffer);
// Compare the two forces to see if they are identical.
double min2, max2;
vector<double> values2;
copy->getFunctionParameters(values2, min2, max2);
ASSERT_EQUAL(min, min2);
ASSERT_EQUAL(max, max2);
ASSERT_EQUAL(values.size(), values2.size());
for (int j = 0; j < (int) values.size(); j++)
ASSERT_EQUAL(values[j], values2[j]);
}
void testContinuous2DFunction() {
// Create a function.
int xsize = 5, ysize = 12;
double xmin = 0.5, xmax = 1.5, ymin = 0.1, ymax = 5.0;
vector<double> values(xsize*ysize);
for (int i = 0; i < (int) values.size(); i++)
values[i] = sin((double) i);
Continuous2DFunction function(xsize, ysize, values, xmin, xmax, ymin, ymax);
// Serialize and then deserialize it.
stringstream buffer;
XmlSerializer::serialize<Continuous2DFunction>(&function, "Function", buffer);
Continuous2DFunction* copy = XmlSerializer::deserialize<Continuous2DFunction>(buffer);
// Compare the two forces to see if they are identical.
int xsize2, ysize2;
double xmin2, xmax2, ymin2, ymax2;
vector<double> values2;
copy->getFunctionParameters(xsize2, ysize2, values2, xmin2, xmax2, ymin2, ymax2);
ASSERT_EQUAL(xsize, xsize2);
ASSERT_EQUAL(ysize, ysize2);
ASSERT_EQUAL(xmin, xmin2);
ASSERT_EQUAL(xmax, xmax2);
ASSERT_EQUAL(ymin, ymin2);
ASSERT_EQUAL(ymax, ymax2);
ASSERT_EQUAL(values.size(), values2.size());
for (int j = 0; j < (int) values.size(); j++)
ASSERT_EQUAL(values[j], values2[j]);
}
void testContinuous3DFunction() {
// Create a function.
int xsize = 5, ysize = 4, zsize = 3;
double xmin = 0.5, xmax = 1.5, ymin = 0.1, ymax = 5.0, zmin = 0.3, zmax = 0.9;
vector<double> values(xsize*ysize*zsize);
for (int i = 0; i < (int) values.size(); i++)
values[i] = sin((double) i);
Continuous3DFunction function(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
// Serialize and then deserialize it.
stringstream buffer;
XmlSerializer::serialize<Continuous3DFunction>(&function, "Function", buffer);
Continuous3DFunction* copy = XmlSerializer::deserialize<Continuous3DFunction>(buffer);
// Compare the two forces to see if they are identical.
int xsize2, ysize2, zsize2;
double xmin2, xmax2, ymin2, ymax2, zmin2, zmax2;
vector<double> values2;
copy->getFunctionParameters(xsize2, ysize2, zsize2, values2, xmin2, xmax2, ymin2, ymax2, zmin2, zmax2);
ASSERT_EQUAL(xsize, xsize2);
ASSERT_EQUAL(ysize, ysize2);
ASSERT_EQUAL(zsize, zsize2);
ASSERT_EQUAL(xmin, xmin2);
ASSERT_EQUAL(xmax, xmax2);
ASSERT_EQUAL(ymin, ymin2);
ASSERT_EQUAL(ymax, ymax2);
ASSERT_EQUAL(zmin, zmin2);
ASSERT_EQUAL(zmax, zmax2);
ASSERT_EQUAL(values.size(), values2.size());
for (int j = 0; j < (int) values.size(); j++)
ASSERT_EQUAL(values[j], values2[j]);
}
void testDiscrete1DFunction() {
// Create a function.
vector<double> values(60);
for (int i = 0; i < (int) values.size(); i++)
values[i] = sin((double) i);
Discrete1DFunction function(values);
// Serialize and then deserialize it.
stringstream buffer;
XmlSerializer::serialize<Discrete1DFunction>(&function, "Function", buffer);
Discrete1DFunction* copy = XmlSerializer::deserialize<Discrete1DFunction>(buffer);
// Compare the two forces to see if they are identical.
vector<double> values2;
copy->getFunctionParameters(values2);
ASSERT_EQUAL(values.size(), values2.size());
for (int j = 0; j < (int) values.size(); j++)
ASSERT_EQUAL(values[j], values2[j]);
}
void testDiscrete2DFunction() {
// Create a function.
int xsize = 5, ysize = 12;
vector<double> values(xsize*ysize);
for (int i = 0; i < (int) values.size(); i++)
values[i] = sin((double) i);
Discrete2DFunction function(xsize, ysize, values);
// Serialize and then deserialize it.
stringstream buffer;
XmlSerializer::serialize<Discrete2DFunction>(&function, "Function", buffer);
Discrete2DFunction* copy = XmlSerializer::deserialize<Discrete2DFunction>(buffer);
// Compare the two forces to see if they are identical.
int xsize2, ysize2;
vector<double> values2;
copy->getFunctionParameters(xsize2, ysize2, values2);
ASSERT_EQUAL(xsize, xsize2);
ASSERT_EQUAL(ysize, ysize2);
ASSERT_EQUAL(values.size(), values2.size());
for (int j = 0; j < (int) values.size(); j++)
ASSERT_EQUAL(values[j], values2[j]);
}
void testDiscrete3DFunction() {
// Create a function.
int xsize = 5, ysize = 4, zsize = 3;
vector<double> values(xsize*ysize*zsize);
for (int i = 0; i < (int) values.size(); i++)
values[i] = sin((double) i);
Discrete3DFunction function(xsize, ysize, zsize, values);
// Serialize and then deserialize it.
stringstream buffer;
XmlSerializer::serialize<Discrete3DFunction>(&function, "Function", buffer);
Discrete3DFunction* copy = XmlSerializer::deserialize<Discrete3DFunction>(buffer);
// Compare the two forces to see if they are identical.
int xsize2, ysize2, zsize2;
vector<double> values2;
copy->getFunctionParameters(xsize2, ysize2, zsize2, values2);
ASSERT_EQUAL(xsize, xsize2);
ASSERT_EQUAL(ysize, ysize2);
ASSERT_EQUAL(zsize, zsize2);
ASSERT_EQUAL(values.size(), values2.size());
for (int j = 0; j < (int) values.size(); j++)
ASSERT_EQUAL(values[j], values2[j]);
}
int main() {
try {
testContinuous1DFunction();
testContinuous2DFunction();
testContinuous3DFunction();
testDiscrete1DFunction();
testDiscrete2DFunction();
testDiscrete3DFunction();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
return 1;
}
cout << "Done" << endl;
return 0;
}
......@@ -84,10 +84,91 @@ void testPeriodicSpline() {
}
}
void test2DSpline() {
const int xsize = 15;
const int ysize = 17;
vector<double> x(xsize);
vector<double> y(ysize);
vector<double> f(xsize*ysize);
for (int i = 0; i < xsize; i++)
x[i] = 0.5*i+0.1*sin(double(i));
for (int i = 0; i < ysize; i++)
y[i] = 0.6*i+0.1*sin(double(i));
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
f[i+j*xsize] = sin(x[i])*cos(0.4*y[j]);
vector<vector<double> > c;
SplineFitter::create2DNaturalSpline(x, y, f, c);
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++) {
double value = SplineFitter::evaluate2DSpline(x, y, f, c, x[i], y[j]);
ASSERT_EQUAL_TOL(f[i+j*xsize], value, 1e-6);
}
for (int i = 0; i < 10; i++) {
for (int j = 0; j < 10; j++) {
double s = x[0]+(i+1)*(x[xsize-1]-x[0])/12.0;
double t = y[0]+(j+1)*(y[ysize-1]-y[0])/12.0;
double value = SplineFitter::evaluate2DSpline(x, y, f, c, s, t);
ASSERT_EQUAL_TOL(sin(s)*cos(0.4*t), value, 0.02);
double dx, dy;
SplineFitter::evaluate2DSplineDerivatives(x, y, f, c, s, t, dx, dy);
ASSERT_EQUAL_TOL(cos(s)*cos(0.4*t), dx, 0.05);
ASSERT_EQUAL_TOL(-0.4*sin(s)*sin(0.4*t), dy, 0.05);
}
}
}
void test3DSpline() {
const int xsize = 8;
const int ysize = 9;
const int zsize = 10;
vector<double> x(xsize);
vector<double> y(ysize);
vector<double> z(zsize);
vector<double> f(xsize*ysize*zsize);
for (int i = 0; i < xsize; i++)
x[i] = 0.2*i+0.02*sin(0.4*double(i));
for (int i = 0; i < ysize; i++)
y[i] = 0.2*i+0.02*sin(0.45*double(i));
for (int i = 0; i < zsize; i++)
z[i] = 0.2*i+0.02*sin(0.5*double(i));
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
for (int k = 0; k < zsize; k++)
f[i+j*xsize+k*xsize*ysize] = sin(x[i])*cos(0.4*y[j])*(1+z[k]);
vector<vector<double> > c;
SplineFitter::create3DNaturalSpline(x, y, z, f, c);
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) {
double value = SplineFitter::evaluate3DSpline(x, y, z, f, c, x[i], y[j], z[k]);
ASSERT_EQUAL_TOL(f[i+j*xsize+k*xsize*ysize], value, 1e-6);
}
}
for (int i = 0; i < 10; i++) {
for (int j = 0; j < 10; j++) {
for (int k = 0; k < 10; k++) {
double s = x[0]+(i+1)*(x[xsize-1]-x[0])/12.0;
double t = y[0]+(j+1)*(y[ysize-1]-y[0])/12.0;
double u = z[0]+(k+1)*(z[zsize-1]-z[0])/12.0;
double value = SplineFitter::evaluate3DSpline(x, y, z, f, c, s, t, u);
ASSERT_EQUAL_TOL(sin(s)*cos(0.4*t)*(1+u), value, 0.02);
double dx, dy, dz;
SplineFitter::evaluate3DSplineDerivatives(x, y, z, f, c, s, t, u, dx, dy, dz);
ASSERT_EQUAL_TOL(cos(s)*cos(0.4*t)*(1+u), dx, 0.1);
ASSERT_EQUAL_TOL(-0.4*sin(s)*sin(0.4*t)*(1+u), dy, 0.1);
ASSERT_EQUAL_TOL(sin(s)*cos(0.4*t), dz, 0.1);
}
}
}
}
int main() {
try {
testNaturalSpline();
testPeriodicSpline();
test2DSpline();
test3DSpline();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
......
......@@ -42,7 +42,7 @@ def getText(subNodePath, node):
def convertOpenMMPrefix(name):
return name.replace('OpenMM::', 'OpenMM_')
OPENMM_RE_PATTERN=re.compile("(.*)OpenMM:[a-zA-Z:]*:(.*)")
OPENMM_RE_PATTERN=re.compile("(.*)OpenMM:[a-zA-Z0-9_:]*:(.*)")
def stripOpenMMPrefix(name, rePattern=OPENMM_RE_PATTERN):
try:
m=rePattern.search(name)
......
......@@ -4151,7 +4151,7 @@
<EnergyTerm type="ParticlePairNoExclusions">
include(chargeGroup1*numChargeGroups+chargeGroup2)*screening1*screening2*138.935456*charge1*charge2/r
</EnergyTerm>
<Function name="sigma" min="0" max="675">
<Function name="sigma" type="Discrete1D">
0.27 0.3 0.285 0.185 0.27 0.285 0.315 0.27 0.235 0.295 0.27 0.30152225 0.355862 0.2932777648 0.235 0.29253280555 0.235 0.29267889715 0.235 0.3817314 0.366188 0.47079995 0.405 0.315 0.3099 0.241
0.3 0.33 0.315 0.265 0.3 0.315 0.345 0.3 0.265 0.325 0.3 0.33152225 0.385862 0.3232777648 0.265 0.32253280555 0.265 0.32267889715 0.265 0.4117314 0.396188 0.50079995 0.435 0.345 0.3399 0.271
0.285 0.315 0.3 0.25 0.285 0.3 0.33 0.285 0.25 0.31 0.285 0.31652225 0.370862 0.3082777648 0.25 0.30753280555 0.25 0.30767889715 0.25 0.3967314 0.381188 0.48579995 0.42 0.33 0.3249 0.256
......
......@@ -1509,7 +1509,17 @@ class CustomGBGenerator:
generator.energyTerms.append((term.text, computationMap[term.attrib['type']]))
for function in element.findall("Function"):
values = [float(x) for x in function.text.split()]
generator.functions.append((function.attrib['name'], values, float(function.attrib['min']), float(function.attrib['max'])))
if 'type' in function.attrib:
type = function.attrib['type']
else:
type = 'Continuous1D'
params = {}
for key in function.attrib:
if key.endswith('size'):
params[key] = int(function.attrib[key])
elif key.endswith('min') or key.endswith('max'):
params[key] = float(function.attrib[key])
generator.functions.append((function.attrib['name'], type, values, params))
def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):
methodMap = {NoCutoff:mm.CustomGBForce.NoCutoff,
......@@ -1526,8 +1536,19 @@ class CustomGBGenerator:
force.addComputedValue(value[0], value[1], value[2])
for term in self.energyTerms:
force.addEnergyTerm(term[0], term[1])
for function in self.functions:
force.addFunction(function[0], function[1], function[2], function[3])
for (name, type, values, params) in self.functions:
if type == 'Continuous1D':
force.addTabulatedFunction(name, mm.Continuous1DFunction(values, params['min'], params['max']))
elif type == 'Continuous2D':
force.addTabulatedFunction(name, mm.Continuous2DFunction(params['xsize'], params['ysize'], values, params['xmin'], params['xmax'], params['ymin'], params['ymax']))
elif type == 'Continuous3D':
force.addTabulatedFunction(name, mm.Continuous2DFunction(params['xsize'], params['ysize'], params['zsize'], values, params['xmin'], params['xmax'], params['ymin'], params['ymax'], params['zmin'], params['zmax']))
elif type == 'Discrete1D':
force.addTabulatedFunction(name, mm.Discrete1DFunction(values))
elif type == 'Discrete2D':
force.addTabulatedFunction(name, mm.Discrete2DFunction(params['xsize'], params['ysize'], values))
elif type == 'Discrete3D':
force.addTabulatedFunction(name, mm.Discrete2DFunction(params['xsize'], params['ysize'], params['zsize'], values))
for atom in data.atoms:
t = data.atomType[atom]
if t in self.typeMap:
......
......@@ -6,9 +6,9 @@ Simbios, the NIH National Center for Physics-Based Simulation of
Biological Structures at Stanford, funded under the NIH Roadmap for
Medical Research, grant U54 GM072970. See https://simtk.org.
Portions copyright (c) 2012 University of Virginia and the Authors.
Portions copyright (c) 2012-2014 University of Virginia and the Authors.
Authors: Christoph Klein, Michael R. Shirts
Contributors: Jason M. Swails
Contributors: Jason M. Swails, Peter Eastman
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
......@@ -31,7 +31,7 @@ USE OR OTHER DEALINGS IN THE SOFTWARE.
from __future__ import division
from simtk.openmm import CustomGBForce
from simtk.openmm import CustomGBForce, Discrete1DFunction
d0=[2.26685,2.32548,2.38397,2.44235,2.50057,2.55867,2.61663,2.67444,
2.73212,2.78965,2.84705,2.9043,2.96141,3.0184,3.07524,3.13196,
......@@ -331,8 +331,8 @@ def GBSAGBnForce(solventDielectric=78.5, soluteDielectric=1, SA=None,
custom.addGlobalParameter("neckScale", 0.361825)
custom.addGlobalParameter("neckCut", 0.68)
custom.addFunction("getd0", d0, 0, 440)
custom.addFunction("getm0", m0, 0, 440)
custom.addTabulatedFunction("getd0", Discrete1DFunction(d0))
custom.addTabulatedFunction("getm0", Discrete1DFunction(m0))
custom.addComputedValue("I", "Ivdw+neckScale*Ineck;"
"Ineck=step(radius1+radius2+neckCut-r)*getm0(index)/(1+100*(r-getd0(index))^2+0.3*1000000*(r-getd0(index))^6);"
......@@ -380,8 +380,8 @@ def GBSAGBn2Force(solventDielectric=78.5, soluteDielectric=1, SA=None,
custom.addGlobalParameter("neckScale", 0.826836)
custom.addGlobalParameter("neckCut", 0.68)
custom.addFunction("getd0", d0, 0, 440)
custom.addFunction("getm0", m0, 0, 440)
custom.addTabulatedFunction("getd0", Discrete1DFunction(d0))
custom.addTabulatedFunction("getm0", Discrete1DFunction(m0))
custom.addComputedValue("I", "Ivdw+neckScale*Ineck;"
"Ineck=step(radius1+radius2+neckCut-r)*getm0(index)/(1+100*(r-getd0(index))^2+0.3*1000000*(r-getd0(index))^6);"
......
......@@ -136,6 +136,10 @@ NO_OUTPUT_ARGS = [('LocalEnergyMinimizer', 'minimize', 'context'),
STEAL_OWNERSHIP = {("Platform", "registerPlatform") : [0],
("System", "addForce") : [0],
("System", "setVirtualSite") : [1],
("CustomNonbondedForce", "addTabulatedFunction") : [1],
("CustomGBForce", "addTabulatedFunction") : [1],
("CustomHbondForce", "addTabulatedFunction") : [1],
("CustomCompoundBondForce", "addTabulatedFunction") : [1],
}
# This is a list of units to attach to return values and method args.
......@@ -179,6 +183,7 @@ UNITS = {
("*", "getSolventDielectric") : (None, ()),
("*", "getStepSize") : ("unit.picosecond", ()),
("*", "getSystem") : (None, ()),
("*", "getTabulatedFunction") : (None, ()),
("*", "getUseDispersionCorrection") : (None, ()),
("*", "getTemperature") : ("unit.kelvin", ()),
("*", "getUseDispersionCorrection") : (None, ()),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment