Commit 112fd213 authored by Charlles Abreu's avatar Charlles Abreu
Browse files

Tests for periodic tabulated functions

parent 5696464c
...@@ -243,13 +243,14 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT ...@@ -243,13 +243,14 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
if (periodic) { if (periodic) {
out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[5]<< ";\n"; out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[5]<< ";\n";
out << "x = (x - floor(x))*" << paramsFloat[6] << ";\n"; out << "x = (x - floor(x))*" << paramsFloat[6] << ";\n";
out << "int index = (int) (floor(x));\n";
} }
else { else {
out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n"; out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n"; out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) " << paramsInt[3] << ");\n";
} }
out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) " << paramsInt[3] << ");\n";
out << "float4 coeff = " << functionNames[i].second << "[index];\n"; out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n"; out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n"; out << "real a = 1.0f-b;\n";
...@@ -272,14 +273,16 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT ...@@ -272,14 +273,16 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[10] << ";\n"; out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[10] << ";\n";
out << "x = (x - floor(x))*" << paramsFloat[0] << ";\n"; out << "x = (x - floor(x))*" << paramsFloat[0] << ";\n";
out << "y = (y - floor(y))*" << paramsFloat[1] << ";\n"; out << "y = (y - floor(y))*" << paramsFloat[1] << ";\n";
out << "int s = (int) floor(x);\n";
out << "int t = (int) floor(y);\n";
} }
else { else {
out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n"; out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n";
out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n"; out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n";
out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n"; out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
} }
out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n"; out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n";
out << "float4 c[4];\n"; out << "float4 c[4];\n";
for (int j = 0; j < 4; j++) for (int j = 0; j < 4; j++)
...@@ -326,16 +329,19 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT ...@@ -326,16 +329,19 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
out << "x = (x - floor(x))*" << paramsFloat[0] << ";\n"; out << "x = (x - floor(x))*" << paramsFloat[0] << ";\n";
out << "y = (y - floor(y))*" << paramsFloat[1] << ";\n"; out << "y = (y - floor(y))*" << paramsFloat[1] << ";\n";
out << "z = (z - floor(z))*" << paramsFloat[2] << ";\n"; out << "z = (z - floor(z))*" << paramsFloat[2] << ";\n";
out << "int s = (int) floor(x);\n";
out << "int t = (int) floor(y);\n";
out << "int u = (int) floor(z);\n";
} }
else { else {
out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n"; out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n";
out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n"; out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n";
out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n"; out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n";
out << "z = (z - " << paramsFloat[7] << ")*" << paramsFloat[11] << ";\n"; out << "z = (z - " << paramsFloat[7] << ")*" << paramsFloat[11] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
out << "int u = min((int) floor(z), " << paramsInt[2] << "-1);\n";
} }
out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
out << "int u = min((int) floor(z), " << paramsInt[2] << "-1);\n";
out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n"; out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n";
out << "float4 c[16];\n"; out << "float4 c[16];\n";
for (int j = 0; j < 16; j++) for (int j = 0; j < 16; j++)
......
...@@ -29,6 +29,9 @@ ...@@ -29,6 +29,9 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. * * USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#ifdef WIN32
#define _USE_MATH_DEFINES // Needed to get M_PI
#endif
#include "openmm/TabulatedFunction.h" #include "openmm/TabulatedFunction.h"
#include "openmm/internal/AssertionUtilities.h" #include "openmm/internal/AssertionUtilities.h"
#include "openmm/serialization/XmlSerializer.h" #include "openmm/serialization/XmlSerializer.h"
......
...@@ -365,12 +365,11 @@ void testPeriodicContinuous1DFunction() { ...@@ -365,12 +365,11 @@ void testPeriodicContinuous1DFunction() {
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r)+1"); CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r)+1");
forceField->addParticle(vector<double>()); forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>()); forceField->addParticle(vector<double>());
vector<double> table; int xsize = 20;
double twoPi = 8.0*atan(1.0); vector<double> table(xsize);
for (int i = 0; i < 20; i++) for (int i = 0; i < xsize; i++)
table.push_back(sin(i*twoPi/20.0)); table[i] = sin(2.0*M_PI*i/(xsize-1));
table.push_back(table[0]); Continuous1DFunction* continuous1DFunction = new Continuous1DFunction(table, 1.0, 2.0*M_PI+1.0, true);
Continuous1DFunction* continuous1DFunction = new Continuous1DFunction(table, 1.0, twoPi+1.0);
forceField->addTabulatedFunction("fn", continuous1DFunction); forceField->addTabulatedFunction("fn", continuous1DFunction);
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
...@@ -388,8 +387,8 @@ void testPeriodicContinuous1DFunction() { ...@@ -388,8 +387,8 @@ void testPeriodicContinuous1DFunction() {
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1); ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
} }
for (int i = 1; i < 20; i++) { for (int i = 1; i < xsize; i++) {
double x = i*twoPi/20.0+1.0; double x = 2.0*M_PI*i/(xsize-1)+1.0;
positions[1] = Vec3(x, 0, 0); positions[1] = Vec3(x, 0, 0);
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Energy); State state = context.getState(State::Energy);
...@@ -446,6 +445,50 @@ void testContinuous2DFunction() { ...@@ -446,6 +445,50 @@ void testContinuous2DFunction() {
} }
} }
void testPeriodicContinuous2DFunction() {
const int xsize = 20;
const int ysize = 21;
const double xmin = 1.0;
const double xmax = 1.0+8.0*M_PI;
const double ymin = 0.0;
const double ymax = 2.0*M_PI;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table(xsize*ysize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
double x = xmin + i*(xmax-xmin)/(xsize-1);
double y = ymin + j*(ymax-ymin)/(ysize-1);
table[i+xsize*j] = sin(0.25*x)*cos(y);
}
}
forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax, true));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (double x = xmin-0.15; x < xmax+0.2; x += 1.0) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.5) {
positions[1] = Vec3(x, 0, 0);
context.setParameter("a", y);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = sin(0.25*x)*cos(y)+1.0;
double force = -0.25*cos(0.25*x)*cos(y);
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
}
}
}
void testContinuous3DFunction() { void testContinuous3DFunction() {
const int xsize = 10; const int xsize = 10;
const int ysize = 11; const int ysize = 11;
...@@ -504,6 +547,60 @@ void testContinuous3DFunction() { ...@@ -504,6 +547,60 @@ void testContinuous3DFunction() {
} }
} }
void testPeriodicContinuous3DFunction() {
const int xsize = 10;
const int ysize = 11;
const int zsize = 12;
const double xmin = 1.0;
const double xmax = 1.0+8.0*M_PI;
const double ymin = 0.0;
const double ymax = 2.0*M_PI;
const double zmin = 0.0;
const double zmax = 2.0*M_PI;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a,b)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addGlobalParameter("b", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table(xsize*ysize*zsize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) {
double x = xmin + i*(xmax-xmin)/(xsize-1);
double y = ymin + j*(ymax-ymin)/(ysize-1);
double z = zmin + k*(zmax-zmin)/(zsize-1);
table[i+xsize*j+xsize*ysize*k] = sin(0.25*x)*cos(y)*(1.0-sin(z));
}
}
}
forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax, true));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (double x = xmin; x < xmax+0.2; x += 1.0) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.5) {
for (double z = zmin-0.15; z < zmax+0.2; z += 0.5) {
positions[1] = Vec3(x, 0, 0);
context.setParameter("a", y);
context.setParameter("b", z);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = sin(0.25*x)*cos(y)*(1.0-sin(z))+1.0;
double force = -0.25*cos(0.25*x)*cos(y)*(1.0-sin(z));
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
}
}
void testDiscrete1DFunction() { void testDiscrete1DFunction() {
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
...@@ -1345,8 +1442,11 @@ int main(int argc, char* argv[]) { ...@@ -1345,8 +1442,11 @@ int main(int argc, char* argv[]) {
testPeriodic(); testPeriodic();
testTriclinic(); testTriclinic();
testContinuous1DFunction(); testContinuous1DFunction();
testPeriodicContinuous1DFunction();
testContinuous2DFunction(); testContinuous2DFunction();
testPeriodicContinuous2DFunction();
testContinuous3DFunction(); testContinuous3DFunction();
testPeriodicContinuous3DFunction();
testDiscrete1DFunction(); testDiscrete1DFunction();
testDiscrete2DFunction(); testDiscrete2DFunction();
testDiscrete3DFunction(); testDiscrete3DFunction();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment