Commit 112fd213 authored by Charlles Abreu's avatar Charlles Abreu
Browse files

Tests for periodic tabulated functions

parent 5696464c
......@@ -243,13 +243,14 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
if (periodic) {
out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[5]<< ";\n";
out << "x = (x - floor(x))*" << paramsFloat[6] << ";\n";
out << "int index = (int) (floor(x));\n";
}
else {
out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) " << paramsInt[3] << ");\n";
}
out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) " << paramsInt[3] << ");\n";
out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n";
......@@ -272,14 +273,16 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[10] << ";\n";
out << "x = (x - floor(x))*" << paramsFloat[0] << ";\n";
out << "y = (y - floor(y))*" << paramsFloat[1] << ";\n";
out << "int s = (int) floor(x);\n";
out << "int t = (int) floor(y);\n";
}
else {
out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n";
out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n";
out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
}
out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n";
out << "float4 c[4];\n";
for (int j = 0; j < 4; j++)
......@@ -326,16 +329,19 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
out << "x = (x - floor(x))*" << paramsFloat[0] << ";\n";
out << "y = (y - floor(y))*" << paramsFloat[1] << ";\n";
out << "z = (z - floor(z))*" << paramsFloat[2] << ";\n";
out << "int s = (int) floor(x);\n";
out << "int t = (int) floor(y);\n";
out << "int u = (int) floor(z);\n";
}
else {
out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n";
out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n";
out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n";
out << "z = (z - " << paramsFloat[7] << ")*" << paramsFloat[11] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
out << "int u = min((int) floor(z), " << paramsInt[2] << "-1);\n";
}
out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
out << "int u = min((int) floor(z), " << paramsInt[2] << "-1);\n";
out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n";
out << "float4 c[16];\n";
for (int j = 0; j < 16; j++)
......
......@@ -29,6 +29,9 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#ifdef WIN32
#define _USE_MATH_DEFINES // Needed to get M_PI
#endif
#include "openmm/TabulatedFunction.h"
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/serialization/XmlSerializer.h"
......
......@@ -365,12 +365,11 @@ void testPeriodicContinuous1DFunction() {
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r)+1");
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
double twoPi = 8.0*atan(1.0);
for (int i = 0; i < 20; i++)
table.push_back(sin(i*twoPi/20.0));
table.push_back(table[0]);
Continuous1DFunction* continuous1DFunction = new Continuous1DFunction(table, 1.0, twoPi+1.0);
int xsize = 20;
vector<double> table(xsize);
for (int i = 0; i < xsize; i++)
table[i] = sin(2.0*M_PI*i/(xsize-1));
Continuous1DFunction* continuous1DFunction = new Continuous1DFunction(table, 1.0, 2.0*M_PI+1.0, true);
forceField->addTabulatedFunction("fn", continuous1DFunction);
system.addForce(forceField);
Context context(system, integrator, platform);
......@@ -388,8 +387,8 @@ void testPeriodicContinuous1DFunction() {
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
}
for (int i = 1; i < 20; i++) {
double x = i*twoPi/20.0+1.0;
for (int i = 1; i < xsize; i++) {
double x = 2.0*M_PI*i/(xsize-1)+1.0;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
......@@ -446,6 +445,50 @@ void testContinuous2DFunction() {
}
}
void testPeriodicContinuous2DFunction() {
const int xsize = 20;
const int ysize = 21;
const double xmin = 1.0;
const double xmax = 1.0+8.0*M_PI;
const double ymin = 0.0;
const double ymax = 2.0*M_PI;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table(xsize*ysize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
double x = xmin + i*(xmax-xmin)/(xsize-1);
double y = ymin + j*(ymax-ymin)/(ysize-1);
table[i+xsize*j] = sin(0.25*x)*cos(y);
}
}
forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax, true));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (double x = xmin-0.15; x < xmax+0.2; x += 1.0) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.5) {
positions[1] = Vec3(x, 0, 0);
context.setParameter("a", y);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = sin(0.25*x)*cos(y)+1.0;
double force = -0.25*cos(0.25*x)*cos(y);
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
}
}
}
void testContinuous3DFunction() {
const int xsize = 10;
const int ysize = 11;
......@@ -504,6 +547,60 @@ void testContinuous3DFunction() {
}
}
void testPeriodicContinuous3DFunction() {
const int xsize = 10;
const int ysize = 11;
const int zsize = 12;
const double xmin = 1.0;
const double xmax = 1.0+8.0*M_PI;
const double ymin = 0.0;
const double ymax = 2.0*M_PI;
const double zmin = 0.0;
const double zmax = 2.0*M_PI;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a,b)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addGlobalParameter("b", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table(xsize*ysize*zsize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) {
double x = xmin + i*(xmax-xmin)/(xsize-1);
double y = ymin + j*(ymax-ymin)/(ysize-1);
double z = zmin + k*(zmax-zmin)/(zsize-1);
table[i+xsize*j+xsize*ysize*k] = sin(0.25*x)*cos(y)*(1.0-sin(z));
}
}
}
forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax, true));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (double x = xmin; x < xmax+0.2; x += 1.0) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.5) {
for (double z = zmin-0.15; z < zmax+0.2; z += 0.5) {
positions[1] = Vec3(x, 0, 0);
context.setParameter("a", y);
context.setParameter("b", z);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = sin(0.25*x)*cos(y)*(1.0-sin(z))+1.0;
double force = -0.25*cos(0.25*x)*cos(y)*(1.0-sin(z));
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
}
}
void testDiscrete1DFunction() {
System system;
system.addParticle(1.0);
......@@ -1345,8 +1442,11 @@ int main(int argc, char* argv[]) {
testPeriodic();
testTriclinic();
testContinuous1DFunction();
testPeriodicContinuous1DFunction();
testContinuous2DFunction();
testPeriodicContinuous2DFunction();
testContinuous3DFunction();
testPeriodicContinuous3DFunction();
testDiscrete1DFunction();
testDiscrete2DFunction();
testDiscrete3DFunction();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment