Unverified Commit 7e1ebf1c authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Merge pull request #3813 from peastman/spline

Parallelize fitting 3D splines
parents e62448c4 cd36ad12
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2010-2014 Stanford University and the Authors. * * Portions copyright (c) 2010-2022 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -29,11 +29,12 @@ ...@@ -29,11 +29,12 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. * * USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include <vector>
#include <math.h>
#include "openmm/internal/SplineFitter.h" #include "openmm/internal/SplineFitter.h"
#include "openmm/internal/ThreadPool.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include <atomic>
#include <vector>
#include <cmath>
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
...@@ -390,8 +391,8 @@ void SplineFitter::evaluate2DSplineDerivatives(const vector<double>& x, const ve ...@@ -390,8 +391,8 @@ void SplineFitter::evaluate2DSplineDerivatives(const vector<double>& x, const ve
} }
void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, bool periodic, vector<vector<double> >& c) { void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, bool periodic, vector<vector<double> >& c) {
int xsize = x.size(), ysize = y.size(), zsize = z.size(); const int xsize = x.size(), ysize = y.size(), zsize = z.size();
int xysize = xsize*ysize; const int xysize = xsize*ysize;
if (periodic) { if (periodic) {
if (xsize < 3 || ysize < 3 || zsize < 3) if (xsize < 3 || ysize < 3 || zsize < 3)
throw OpenMMException("create3DNaturalSpline: periodic spline must have at least three points along each axis"); throw OpenMMException("create3DNaturalSpline: periodic spline must have at least three points along each axis");
...@@ -402,103 +403,113 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& ...@@ -402,103 +403,113 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>&
throw OpenMMException("create2DNaturalSpline: incorrect number of values"); throw OpenMMException("create2DNaturalSpline: incorrect number of values");
vector<double> d1(xsize*ysize*zsize), d2(xsize*ysize*zsize), d3(xsize*ysize*zsize); vector<double> d1(xsize*ysize*zsize), d2(xsize*ysize*zsize), d3(xsize*ysize*zsize);
vector<double> d12(xsize*ysize*zsize), d13(xsize*ysize*zsize), d23(xsize*ysize*zsize), d123(xsize*ysize*zsize); vector<double> d12(xsize*ysize*zsize), d13(xsize*ysize*zsize), d23(xsize*ysize*zsize), d123(xsize*ysize*zsize);
vector<double> t(xsize), deriv(xsize); ThreadPool threads;
// Compute derivatives with respect to x. threads.execute([&] (ThreadPool& threads, int threadIndex) {
// Compute derivatives with respect to x.
for (int i = 0; i < ysize; i++) {
for (int j = 0; j < zsize; j++) { vector<double> t(xsize), deriv(xsize);
for (int k = 0; k < xsize; k++) for (int i = threadIndex; i < ysize; i += threads.getNumThreads()) {
t[k] = values[k+xsize*i+xysize*j]; for (int j = 0; j < zsize; j++) {
SplineFitter::createSpline(x, t, periodic, deriv); for (int k = 0; k < xsize; k++)
for (int k = 0; k < xsize; k++) t[k] = values[k+xsize*i+xysize*j];
d1[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]); SplineFitter::createSpline(x, t, periodic, deriv);
for (int k = 0; k < xsize; k++)
d1[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]);
}
} }
}
// Compute derivatives with respect to y.
t.resize(ysize); // Compute derivatives with respect to y.
deriv.resize(ysize);
for (int i = 0; i < xsize; i++) { t.resize(ysize);
for (int j = 0; j < zsize; j++) { deriv.resize(ysize);
for (int k = 0; k < ysize; k++) for (int i = threadIndex; i < xsize; i += threads.getNumThreads()) {
t[k] = values[i+xsize*k+xysize*j]; for (int j = 0; j < zsize; j++) {
SplineFitter::createSpline(y, t, periodic, deriv); for (int k = 0; k < ysize; k++)
for (int k = 0; k < ysize; k++) t[k] = values[i+xsize*k+xysize*j];
d2[i+xsize*k+xysize*j] = SplineFitter::evaluateSplineDerivative(y, t, deriv, y[k]); SplineFitter::createSpline(y, t, periodic, deriv);
for (int k = 0; k < ysize; k++)
d2[i+xsize*k+xysize*j] = SplineFitter::evaluateSplineDerivative(y, t, deriv, y[k]);
}
} }
}
// Compute derivatives with respect to z.
t.resize(zsize); // Compute derivatives with respect to z.
deriv.resize(zsize);
for (int i = 0; i < xsize; i++) { t.resize(zsize);
for (int j = 0; j < ysize; j++) { deriv.resize(zsize);
for (int k = 0; k < zsize; k++) for (int i = threadIndex; i < xsize; i += threads.getNumThreads()) {
t[k] = values[i+xsize*j+xysize*k]; for (int j = 0; j < ysize; j++) {
SplineFitter::createSpline(z, t, periodic, deriv); for (int k = 0; k < zsize; k++)
for (int k = 0; k < zsize; k++) t[k] = values[i+xsize*j+xysize*k];
d3[i+xsize*j+xysize*k] = SplineFitter::evaluateSplineDerivative(z, t, deriv, z[k]); SplineFitter::createSpline(z, t, periodic, deriv);
for (int k = 0; k < zsize; k++)
d3[i+xsize*j+xysize*k] = SplineFitter::evaluateSplineDerivative(z, t, deriv, z[k]);
}
} }
}
// Compute second derivatives with respect to x and y. // Compute second derivatives with respect to x and y.
t.resize(xsize); t.resize(xsize);
deriv.resize(xsize); deriv.resize(xsize);
for (int i = 0; i < ysize; i++) { threads.syncThreads();
for (int j = 0; j < zsize; j++) { for (int i = threadIndex; i < ysize; i += threads.getNumThreads()) {
for (int k = 0; k < xsize; k++) for (int j = 0; j < zsize; j++) {
t[k] = d2[k+xsize*i+xysize*j]; for (int k = 0; k < xsize; k++)
SplineFitter::createSpline(x, t, periodic, deriv); t[k] = d2[k+xsize*i+xysize*j];
for (int k = 0; k < xsize; k++) SplineFitter::createSpline(x, t, periodic, deriv);
d12[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]); for (int k = 0; k < xsize; k++)
d12[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]);
}
} }
}
// Compute second derivatives with respect to y and z.
t.resize(ysize); // Compute second derivatives with respect to y and z.
deriv.resize(ysize);
for (int i = 0; i < zsize; i++) { t.resize(ysize);
for (int j = 0; j < xsize; j++) { deriv.resize(ysize);
for (int k = 0; k < ysize; k++) for (int i = threadIndex; i < zsize; i += threads.getNumThreads()) {
t[k] = d3[j+xsize*k+xysize*i]; for (int j = 0; j < xsize; j++) {
SplineFitter::createSpline(y, t, periodic, deriv); for (int k = 0; k < ysize; k++)
for (int k = 0; k < ysize; k++) t[k] = d3[j+xsize*k+xysize*i];
d23[j+xsize*k+xysize*i] = SplineFitter::evaluateSplineDerivative(y, t, deriv, y[k]); SplineFitter::createSpline(y, t, periodic, deriv);
for (int k = 0; k < ysize; k++)
d23[j+xsize*k+xysize*i] = SplineFitter::evaluateSplineDerivative(y, t, deriv, y[k]);
}
} }
}
// Compute second derivatives with respect to x and z. // Compute second derivatives with respect to x and z.
t.resize(zsize); t.resize(zsize);
deriv.resize(zsize); deriv.resize(zsize);
for (int i = 0; i < xsize; i++) { for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) { for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) for (int k = 0; k < zsize; k++)
t[k] = d1[i+xsize*j+xysize*k]; t[k] = d1[i+xsize*j+xysize*k];
SplineFitter::createSpline(z, t, periodic, deriv); SplineFitter::createSpline(z, t, periodic, deriv);
for (int k = 0; k < zsize; k++) for (int k = 0; k < zsize; k++)
d13[i+xsize*j+xysize*k] = SplineFitter::evaluateSplineDerivative(z, t, deriv, z[k]); d13[i+xsize*j+xysize*k] = SplineFitter::evaluateSplineDerivative(z, t, deriv, z[k]);
}
} }
}
// Compute third derivatives with respect to x, y, and z.
t.resize(xsize); // Compute third derivatives with respect to x, y, and z.
deriv.resize(xsize);
for (int i = 0; i < ysize; i++) { t.resize(xsize);
for (int j = 0; j < zsize; j++) { deriv.resize(xsize);
for (int k = 0; k < xsize; k++) threads.syncThreads();
t[k] = d23[k+xsize*i+xysize*j]; for (int i = threadIndex; i < ysize; i += threads.getNumThreads()) {
SplineFitter::createSpline(x, t, periodic, deriv); for (int j = 0; j < zsize; j++) {
for (int k = 0; k < xsize; k++) for (int k = 0; k < xsize; k++)
d123[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]); t[k] = d23[k+xsize*i+xysize*j];
SplineFitter::createSpline(x, t, periodic, deriv);
for (int k = 0; k < xsize; k++)
d123[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]);
}
} }
} });
threads.waitForThreads();
threads.resumeThreads();
threads.waitForThreads();
threads.resumeThreads();
threads.waitForThreads();
// Now compute the coefficients. This involves multiplying by a sparse 64x64 matrix, given // Now compute the coefficients. This involves multiplying by a sparse 64x64 matrix, given
// here in packed form. // here in packed form.
...@@ -578,49 +589,56 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& ...@@ -578,49 +589,56 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>&
weight[i].push_back(wt[index++]); weight[i].push_back(wt[index++]);
} }
} }
vector<double> rhs(64);
c.resize((xsize-1)*(ysize-1)*(zsize-1)); c.resize((xsize-1)*(ysize-1)*(zsize-1));
for (int i = 0; i < xsize-1; i++) { atomic<int> atomicCounter(0);
for (int j = 0; j < ysize-1; j++) { threads.execute([&] (ThreadPool& threads, int threadIndex) {
for (int k = 0; k < zsize-1; k++) { vector<double> rhs(64);
// Compute the 64 coefficients for patch (i, j, k). while (true) {
int i = atomicCounter++;
int nexti = i+1; if (i >= xsize-1)
int nextj = j+1; break;
int nextk = k+1; for (int j = 0; j < ysize-1; j++) {
double deltax = x[nexti]-x[i]; for (int k = 0; k < zsize-1; k++) {
double deltay = y[nextj]-y[j]; // Compute the 64 coefficients for patch (i, j, k).
double deltaz = z[nextk]-z[k];
double e[] = {values[i+j*xsize+k*xysize], values[nexti+j*xsize+k*xysize], values[i+nextj*xsize+k*xysize], values[nexti+nextj*xsize+k*xysize], values[i+j*xsize+nextk*xysize], values[nexti+j*xsize+nextk*xysize], values[i+nextj*xsize+nextk*xysize], values[nexti+nextj*xsize+nextk*xysize]}; int nexti = i+1;
double e1[] = {d1[i+j*xsize+k*xysize], d1[nexti+j*xsize+k*xysize], d1[i+nextj*xsize+k*xysize], d1[nexti+nextj*xsize+k*xysize], d1[i+j*xsize+nextk*xysize], d1[nexti+j*xsize+nextk*xysize], d1[i+nextj*xsize+nextk*xysize], d1[nexti+nextj*xsize+nextk*xysize]}; int nextj = j+1;
double e2[] = {d2[i+j*xsize+k*xysize], d2[nexti+j*xsize+k*xysize], d2[i+nextj*xsize+k*xysize], d2[nexti+nextj*xsize+k*xysize], d2[i+j*xsize+nextk*xysize], d2[nexti+j*xsize+nextk*xysize], d2[i+nextj*xsize+nextk*xysize], d2[nexti+nextj*xsize+nextk*xysize]}; int nextk = k+1;
double e3[] = {d3[i+j*xsize+k*xysize], d3[nexti+j*xsize+k*xysize], d3[i+nextj*xsize+k*xysize], d3[nexti+nextj*xsize+k*xysize], d3[i+j*xsize+nextk*xysize], d3[nexti+j*xsize+nextk*xysize], d3[i+nextj*xsize+nextk*xysize], d3[nexti+nextj*xsize+nextk*xysize]}; double deltax = x[nexti]-x[i];
double e12[] = {d12[i+j*xsize+k*xysize], d12[nexti+j*xsize+k*xysize], d12[i+nextj*xsize+k*xysize], d12[nexti+nextj*xsize+k*xysize], d12[i+j*xsize+nextk*xysize], d12[nexti+j*xsize+nextk*xysize], d12[i+nextj*xsize+nextk*xysize], d12[nexti+nextj*xsize+nextk*xysize]}; double deltay = y[nextj]-y[j];
double e13[] = {d13[i+j*xsize+k*xysize], d13[nexti+j*xsize+k*xysize], d13[i+nextj*xsize+k*xysize], d13[nexti+nextj*xsize+k*xysize], d13[i+j*xsize+nextk*xysize], d13[nexti+j*xsize+nextk*xysize], d13[i+nextj*xsize+nextk*xysize], d13[nexti+nextj*xsize+nextk*xysize]}; double deltaz = z[nextk]-z[k];
double e23[] = {d23[i+j*xsize+k*xysize], d23[nexti+j*xsize+k*xysize], d23[i+nextj*xsize+k*xysize], d23[nexti+nextj*xsize+k*xysize], d23[i+j*xsize+nextk*xysize], d23[nexti+j*xsize+nextk*xysize], d23[i+nextj*xsize+nextk*xysize], d23[nexti+nextj*xsize+nextk*xysize]}; double e[] = {values[i+j*xsize+k*xysize], values[nexti+j*xsize+k*xysize], values[i+nextj*xsize+k*xysize], values[nexti+nextj*xsize+k*xysize], values[i+j*xsize+nextk*xysize], values[nexti+j*xsize+nextk*xysize], values[i+nextj*xsize+nextk*xysize], values[nexti+nextj*xsize+nextk*xysize]};
double e123[] = {d123[i+j*xsize+k*xysize], d123[nexti+j*xsize+k*xysize], d123[i+nextj*xsize+k*xysize], d123[nexti+nextj*xsize+k*xysize], d123[i+j*xsize+nextk*xysize], d123[nexti+j*xsize+nextk*xysize], d123[i+nextj*xsize+nextk*xysize], d123[nexti+nextj*xsize+nextk*xysize]}; double e1[] = {d1[i+j*xsize+k*xysize], d1[nexti+j*xsize+k*xysize], d1[i+nextj*xsize+k*xysize], d1[nexti+nextj*xsize+k*xysize], d1[i+j*xsize+nextk*xysize], d1[nexti+j*xsize+nextk*xysize], d1[i+nextj*xsize+nextk*xysize], d1[nexti+nextj*xsize+nextk*xysize]};
for (int m = 0; m < 8; m++) { double e2[] = {d2[i+j*xsize+k*xysize], d2[nexti+j*xsize+k*xysize], d2[i+nextj*xsize+k*xysize], d2[nexti+nextj*xsize+k*xysize], d2[i+j*xsize+nextk*xysize], d2[nexti+j*xsize+nextk*xysize], d2[i+nextj*xsize+nextk*xysize], d2[nexti+nextj*xsize+nextk*xysize]};
rhs[m] = e[m]; double e3[] = {d3[i+j*xsize+k*xysize], d3[nexti+j*xsize+k*xysize], d3[i+nextj*xsize+k*xysize], d3[nexti+nextj*xsize+k*xysize], d3[i+j*xsize+nextk*xysize], d3[nexti+j*xsize+nextk*xysize], d3[i+nextj*xsize+nextk*xysize], d3[nexti+nextj*xsize+nextk*xysize]};
rhs[m+8] = e1[m]*deltax; double e12[] = {d12[i+j*xsize+k*xysize], d12[nexti+j*xsize+k*xysize], d12[i+nextj*xsize+k*xysize], d12[nexti+nextj*xsize+k*xysize], d12[i+j*xsize+nextk*xysize], d12[nexti+j*xsize+nextk*xysize], d12[i+nextj*xsize+nextk*xysize], d12[nexti+nextj*xsize+nextk*xysize]};
rhs[m+16] = e2[m]*deltay; double e13[] = {d13[i+j*xsize+k*xysize], d13[nexti+j*xsize+k*xysize], d13[i+nextj*xsize+k*xysize], d13[nexti+nextj*xsize+k*xysize], d13[i+j*xsize+nextk*xysize], d13[nexti+j*xsize+nextk*xysize], d13[i+nextj*xsize+nextk*xysize], d13[nexti+nextj*xsize+nextk*xysize]};
rhs[m+24] = e3[m]*deltaz; double e23[] = {d23[i+j*xsize+k*xysize], d23[nexti+j*xsize+k*xysize], d23[i+nextj*xsize+k*xysize], d23[nexti+nextj*xsize+k*xysize], d23[i+j*xsize+nextk*xysize], d23[nexti+j*xsize+nextk*xysize], d23[i+nextj*xsize+nextk*xysize], d23[nexti+nextj*xsize+nextk*xysize]};
rhs[m+32] = e12[m]*deltax*deltay; double e123[] = {d123[i+j*xsize+k*xysize], d123[nexti+j*xsize+k*xysize], d123[i+nextj*xsize+k*xysize], d123[nexti+nextj*xsize+k*xysize], d123[i+j*xsize+nextk*xysize], d123[nexti+j*xsize+nextk*xysize], d123[i+nextj*xsize+nextk*xysize], d123[nexti+nextj*xsize+nextk*xysize]};
rhs[m+40] = e13[m]*deltax*deltaz; for (int m = 0; m < 8; m++) {
rhs[m+48] = e23[m]*deltay*deltaz; rhs[m] = e[m];
rhs[m+56] = e123[m]*deltax*deltay*deltaz; rhs[m+8] = e1[m]*deltax;
} rhs[m+16] = e2[m]*deltay;
vector<double>& coeff = c[i+j*(xsize-1)+k*(xsize-1)*(ysize-1)]; rhs[m+24] = e3[m]*deltaz;
coeff.resize(64); rhs[m+32] = e12[m]*deltax*deltay;
for (int m = 0; m < 64; m++) { rhs[m+40] = e13[m]*deltax*deltaz;
double sum = 0.0; rhs[m+48] = e23[m]*deltay*deltaz;
int numElements = weight[m].size(); rhs[m+56] = e123[m]*deltax*deltay*deltaz;
for (int n = 0; n < numElements; n += 2) }
sum += weight[m][n+1]*rhs[weight[m][n]]; vector<double>& coeff = c[i+j*(xsize-1)+k*(xsize-1)*(ysize-1)];
coeff[m] = sum; coeff.resize(64);
for (int m = 0; m < 64; m++) {
double sum = 0.0;
int numElements = weight[m].size();
for (int n = 0; n < numElements; n += 2)
sum += weight[m][n+1]*rhs[weight[m][n]];
coeff[m] = sum;
}
} }
} }
} }
} });
threads.waitForThreads();
} }
void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, vector<vector<double> >& c) { void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, vector<vector<double> >& c) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment