"vscode:/vscode.git/clone" did not exist on "83dcb4e2a40a612125a81a1459b246bc31cf49ba"
Commit cd36ad12 authored by peastman
Browse files

Parallelize fitting 3D splines

parent f9cebbf0
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2010-2014 Stanford University and the Authors. * * Portions copyright (c) 2010-2022 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -29,11 +29,12 @@ ...@@ -29,11 +29,12 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. * * USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include <vector>
#include <math.h>
#include "openmm/internal/SplineFitter.h" #include "openmm/internal/SplineFitter.h"
#include "openmm/internal/ThreadPool.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include <atomic>
#include <vector>
#include <cmath>
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
...@@ -390,8 +391,8 @@ void SplineFitter::evaluate2DSplineDerivatives(const vector<double>& x, const ve ...@@ -390,8 +391,8 @@ void SplineFitter::evaluate2DSplineDerivatives(const vector<double>& x, const ve
} }
void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, bool periodic, vector<vector<double> >& c) { void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, bool periodic, vector<vector<double> >& c) {
int xsize = x.size(), ysize = y.size(), zsize = z.size(); const int xsize = x.size(), ysize = y.size(), zsize = z.size();
int xysize = xsize*ysize; const int xysize = xsize*ysize;
if (periodic) { if (periodic) {
if (xsize < 3 || ysize < 3 || zsize < 3) if (xsize < 3 || ysize < 3 || zsize < 3)
throw OpenMMException("create3DNaturalSpline: periodic spline must have at least three points along each axis"); throw OpenMMException("create3DNaturalSpline: periodic spline must have at least three points along each axis");
...@@ -402,11 +403,13 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& ...@@ -402,11 +403,13 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>&
throw OpenMMException("create2DNaturalSpline: incorrect number of values"); throw OpenMMException("create2DNaturalSpline: incorrect number of values");
vector<double> d1(xsize*ysize*zsize), d2(xsize*ysize*zsize), d3(xsize*ysize*zsize); vector<double> d1(xsize*ysize*zsize), d2(xsize*ysize*zsize), d3(xsize*ysize*zsize);
vector<double> d12(xsize*ysize*zsize), d13(xsize*ysize*zsize), d23(xsize*ysize*zsize), d123(xsize*ysize*zsize); vector<double> d12(xsize*ysize*zsize), d13(xsize*ysize*zsize), d23(xsize*ysize*zsize), d123(xsize*ysize*zsize);
vector<double> t(xsize), deriv(xsize); ThreadPool threads;
threads.execute([&] (ThreadPool& threads, int threadIndex) {
// Compute derivatives with respect to x. // Compute derivatives with respect to x.
for (int i = 0; i < ysize; i++) { vector<double> t(xsize), deriv(xsize);
for (int i = threadIndex; i < ysize; i += threads.getNumThreads()) {
for (int j = 0; j < zsize; j++) { for (int j = 0; j < zsize; j++) {
for (int k = 0; k < xsize; k++) for (int k = 0; k < xsize; k++)
t[k] = values[k+xsize*i+xysize*j]; t[k] = values[k+xsize*i+xysize*j];
...@@ -420,7 +423,7 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& ...@@ -420,7 +423,7 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>&
t.resize(ysize); t.resize(ysize);
deriv.resize(ysize); deriv.resize(ysize);
for (int i = 0; i < xsize; i++) { for (int i = threadIndex; i < xsize; i += threads.getNumThreads()) {
for (int j = 0; j < zsize; j++) { for (int j = 0; j < zsize; j++) {
for (int k = 0; k < ysize; k++) for (int k = 0; k < ysize; k++)
t[k] = values[i+xsize*k+xysize*j]; t[k] = values[i+xsize*k+xysize*j];
...@@ -434,7 +437,7 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& ...@@ -434,7 +437,7 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>&
t.resize(zsize); t.resize(zsize);
deriv.resize(zsize); deriv.resize(zsize);
for (int i = 0; i < xsize; i++) { for (int i = threadIndex; i < xsize; i += threads.getNumThreads()) {
for (int j = 0; j < ysize; j++) { for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) for (int k = 0; k < zsize; k++)
t[k] = values[i+xsize*j+xysize*k]; t[k] = values[i+xsize*j+xysize*k];
...@@ -448,7 +451,8 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& ...@@ -448,7 +451,8 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>&
t.resize(xsize); t.resize(xsize);
deriv.resize(xsize); deriv.resize(xsize);
for (int i = 0; i < ysize; i++) { threads.syncThreads();
for (int i = threadIndex; i < ysize; i += threads.getNumThreads()) {
for (int j = 0; j < zsize; j++) { for (int j = 0; j < zsize; j++) {
for (int k = 0; k < xsize; k++) for (int k = 0; k < xsize; k++)
t[k] = d2[k+xsize*i+xysize*j]; t[k] = d2[k+xsize*i+xysize*j];
...@@ -462,7 +466,7 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& ...@@ -462,7 +466,7 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>&
t.resize(ysize); t.resize(ysize);
deriv.resize(ysize); deriv.resize(ysize);
for (int i = 0; i < zsize; i++) { for (int i = threadIndex; i < zsize; i += threads.getNumThreads()) {
for (int j = 0; j < xsize; j++) { for (int j = 0; j < xsize; j++) {
for (int k = 0; k < ysize; k++) for (int k = 0; k < ysize; k++)
t[k] = d3[j+xsize*k+xysize*i]; t[k] = d3[j+xsize*k+xysize*i];
...@@ -490,7 +494,8 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& ...@@ -490,7 +494,8 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>&
t.resize(xsize); t.resize(xsize);
deriv.resize(xsize); deriv.resize(xsize);
for (int i = 0; i < ysize; i++) { threads.syncThreads();
for (int i = threadIndex; i < ysize; i += threads.getNumThreads()) {
for (int j = 0; j < zsize; j++) { for (int j = 0; j < zsize; j++) {
for (int k = 0; k < xsize; k++) for (int k = 0; k < xsize; k++)
t[k] = d23[k+xsize*i+xysize*j]; t[k] = d23[k+xsize*i+xysize*j];
...@@ -499,6 +504,12 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& ...@@ -499,6 +504,12 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>&
d123[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]); d123[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]);
} }
} }
});
threads.waitForThreads();
threads.resumeThreads();
threads.waitForThreads();
threads.resumeThreads();
threads.waitForThreads();
// Now compute the coefficients. This involves multiplying by a sparse 64x64 matrix, given // Now compute the coefficients. This involves multiplying by a sparse 64x64 matrix, given
// here in packed form. // here in packed form.
...@@ -578,9 +589,14 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& ...@@ -578,9 +589,14 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>&
weight[i].push_back(wt[index++]); weight[i].push_back(wt[index++]);
} }
} }
vector<double> rhs(64);
c.resize((xsize-1)*(ysize-1)*(zsize-1)); c.resize((xsize-1)*(ysize-1)*(zsize-1));
for (int i = 0; i < xsize-1; i++) { atomic<int> atomicCounter(0);
threads.execute([&] (ThreadPool& threads, int threadIndex) {
vector<double> rhs(64);
while (true) {
int i = atomicCounter++;
if (i >= xsize-1)
break;
for (int j = 0; j < ysize-1; j++) { for (int j = 0; j < ysize-1; j++) {
for (int k = 0; k < zsize-1; k++) { for (int k = 0; k < zsize-1; k++) {
// Compute the 64 coefficients for patch (i, j, k). // Compute the 64 coefficients for patch (i, j, k).
...@@ -621,6 +637,8 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& ...@@ -621,6 +637,8 @@ void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>&
} }
} }
} }
});
threads.waitForThreads();
} }
void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, vector<vector<double> >& c) { void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, vector<vector<double> >& c) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment