Commit c3c8ec55 (unverified), authored by Peter Eastman, committed by GitHub
Browse files

Vectorized calculating long range correction coefficient (#3606)

parent b096cd7c
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include "openmm/Kernel.h" #include "openmm/Kernel.h"
#include "openmm/internal/ThreadPool.h" #include "openmm/internal/ThreadPool.h"
#include "lepton/CompiledExpression.h" #include "lepton/CompiledExpression.h"
#include "lepton/CompiledVectorExpression.h"
#include <utility> #include <utility>
#include <map> #include <map>
#include <string> #include <string>
...@@ -69,7 +70,7 @@ public: ...@@ -69,7 +70,7 @@ public:
* the Context (such as global parameters). This allows the coefficient to be updated * the Context (such as global parameters). This allows the coefficient to be updated
* more quickly when global parameters change. * more quickly when global parameters change.
*/ */
static LongRangeCorrectionData prepareLongRangeCorrection(const CustomNonbondedForce& force); static LongRangeCorrectionData prepareLongRangeCorrection(const CustomNonbondedForce& force, int numThreads);
/** /**
* Compute the coefficient which, when divided by the periodic box volume, gives the * Compute the coefficient which, when divided by the periodic box volume, gives the
* long range correction to the energy. If the Force computes parameter derivatives, * long range correction to the energy. If the Force computes parameter derivatives,
...@@ -77,7 +78,7 @@ public: ...@@ -77,7 +78,7 @@ public:
*/ */
static void calcLongRangeCorrection(const CustomNonbondedForce& force, LongRangeCorrectionData& data, const Context& context, double& coefficient, std::vector<double>& derivatives, ThreadPool& threads); static void calcLongRangeCorrection(const CustomNonbondedForce& force, LongRangeCorrectionData& data, const Context& context, double& coefficient, std::vector<double>& derivatives, ThreadPool& threads);
private: private:
static double integrateInteraction(Lepton::CompiledExpression& expression, const std::vector<double>& params1, const std::vector<double>& params2, static double integrateInteraction(Lepton::CompiledVectorExpression& expression, const std::vector<double>& params1, const std::vector<double>& params2,
const std::vector<double>& computedValues1, const std::vector<double>& computedValues2, const CustomNonbondedForce& force, const Context& context, const std::vector<double>& computedValues1, const std::vector<double>& computedValues2, const CustomNonbondedForce& force, const Context& context,
const std::vector<std::string>& paramNames, const std::vector<std::string>& computedValueNames); const std::vector<std::string>& paramNames, const std::vector<std::string>& computedValueNames);
const CustomNonbondedForce& owner; const CustomNonbondedForce& owner;
...@@ -90,8 +91,9 @@ public: ...@@ -90,8 +91,9 @@ public:
std::vector<std::vector<double> > classes; std::vector<std::vector<double> > classes;
std::vector<std::string> paramNames, computedValueNames; std::vector<std::string> paramNames, computedValueNames;
std::map<std::pair<int, int>, long long int> interactionCount; std::map<std::pair<int, int>, long long int> interactionCount;
Lepton::CompiledExpression energyExpression; std::vector<Lepton::CompiledVectorExpression> energyExpression;
std::vector<Lepton::CompiledExpression> derivExpressions, computedValueExpressions; std::vector<std::vector<Lepton::CompiledVectorExpression> > derivExpressions;
std::vector<Lepton::CompiledExpression> computedValueExpressions;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -162,7 +162,7 @@ void CustomNonbondedForceImpl::updateParametersInContext(ContextImpl& context) { ...@@ -162,7 +162,7 @@ void CustomNonbondedForceImpl::updateParametersInContext(ContextImpl& context) {
context.systemChanged(); context.systemChanged();
} }
CustomNonbondedForceImpl::LongRangeCorrectionData CustomNonbondedForceImpl::prepareLongRangeCorrection(const CustomNonbondedForce& force) { CustomNonbondedForceImpl::LongRangeCorrectionData CustomNonbondedForceImpl::prepareLongRangeCorrection(const CustomNonbondedForce& force, int numThreads) {
LongRangeCorrectionData data; LongRangeCorrectionData data;
data.method = force.getNonbondedMethod(); data.method = force.getNonbondedMethod();
if (data.method == CustomNonbondedForce::NoCutoff || data.method == CustomNonbondedForce::CutoffNonPeriodic) if (data.method == CustomNonbondedForce::NoCutoff || data.method == CustomNonbondedForce::CutoffNonPeriodic)
...@@ -227,12 +227,19 @@ CustomNonbondedForceImpl::LongRangeCorrectionData CustomNonbondedForceImpl::prep ...@@ -227,12 +227,19 @@ CustomNonbondedForceImpl::LongRangeCorrectionData CustomNonbondedForceImpl::prep
// Prepare for evaluating the expressions. // Prepare for evaluating the expressions.
int width = Lepton::CompiledVectorExpression::getAllowedWidths().back();
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++) for (int i = 0; i < force.getNumFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
data.energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), functions).createCompiledExpression(); Lepton::CompiledVectorExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), functions).createCompiledVectorExpression(width);
for (int k = 0; k < force.getNumEnergyParameterDerivatives(); k++) for (int i = 0; i < numThreads; i++)
data.derivExpressions.push_back(Lepton::Parser::parse(force.getEnergyFunction(), functions).differentiate(force.getEnergyParameterDerivativeName(k)).createCompiledExpression()); data.energyExpression.push_back(energyExpression);
data.derivExpressions.resize(numThreads);
for (int k = 0; k < force.getNumEnergyParameterDerivatives(); k++) {
Lepton::CompiledVectorExpression derivExpression = Lepton::Parser::parse(force.getEnergyFunction(), functions).differentiate(force.getEnergyParameterDerivativeName(k)).createCompiledVectorExpression(width);
for (int i = 0; i < numThreads; i++)
data.derivExpressions[i].push_back(derivExpression);
}
for (int i = 0; i < force.getNumPerParticleParameters(); i++) { for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
string name = force.getPerParticleParameterName(i); string name = force.getPerParticleParameterName(i);
data.paramNames.push_back(name+"1"); data.paramNames.push_back(name+"1");
...@@ -283,7 +290,7 @@ void CustomNonbondedForceImpl::calcLongRangeCorrection(const CustomNonbondedForc ...@@ -283,7 +290,7 @@ void CustomNonbondedForceImpl::calcLongRangeCorrection(const CustomNonbondedForc
vector<double> threadSum(threads.getNumThreads(), 0.0); vector<double> threadSum(threads.getNumThreads(), 0.0);
atomic<int> atomicCounter(0); atomic<int> atomicCounter(0);
threads.execute([&] (ThreadPool& threads, int threadIndex) { threads.execute([&] (ThreadPool& threads, int threadIndex) {
Lepton::CompiledExpression expression = data.energyExpression; Lepton::CompiledVectorExpression& expression = data.energyExpression[threadIndex];
while (true) { while (true) {
int i = atomicCounter++; int i = atomicCounter++;
if (i >= numClasses) if (i >= numClasses)
...@@ -302,13 +309,13 @@ void CustomNonbondedForceImpl::calcLongRangeCorrection(const CustomNonbondedForc ...@@ -302,13 +309,13 @@ void CustomNonbondedForceImpl::calcLongRangeCorrection(const CustomNonbondedForc
// Now do the same for parameter derivatives. // Now do the same for parameter derivatives.
int numDerivs = data.derivExpressions.size(); int numDerivs = data.derivExpressions[0].size();
derivatives.resize(numDerivs); derivatives.resize(numDerivs);
for (int k = 0; k < numDerivs; k++) { for (int k = 0; k < numDerivs; k++) {
atomicCounter = 0; atomicCounter = 0;
threads.execute([&] (ThreadPool& threads, int threadIndex) { threads.execute([&] (ThreadPool& threads, int threadIndex) {
threadSum[threadIndex] = 0; threadSum[threadIndex] = 0;
Lepton::CompiledExpression expression = data.derivExpressions[k]; Lepton::CompiledVectorExpression& expression = data.derivExpressions[threadIndex][k];
while (true) { while (true) {
int i = atomicCounter++; int i = atomicCounter++;
if (i >= numClasses) if (i >= numClasses)
...@@ -327,35 +334,51 @@ void CustomNonbondedForceImpl::calcLongRangeCorrection(const CustomNonbondedForc ...@@ -327,35 +334,51 @@ void CustomNonbondedForceImpl::calcLongRangeCorrection(const CustomNonbondedForc
} }
} }
double CustomNonbondedForceImpl::integrateInteraction(Lepton::CompiledExpression& expression, const vector<double>& params1, const vector<double>& params2, double CustomNonbondedForceImpl::integrateInteraction(Lepton::CompiledVectorExpression& expression, const vector<double>& params1, const vector<double>& params2,
const vector<double>& computedValues1, const vector<double>& computedValues2, const CustomNonbondedForce& force, const Context& context, const vector<double>& computedValues1, const vector<double>& computedValues2, const CustomNonbondedForce& force, const Context& context,
const vector<string>& paramNames, const vector<string>& computedValueNames) { const vector<string>& paramNames, const vector<string>& computedValueNames) {
int width = expression.getWidth();
const set<string>& variables = expression.getVariables(); const set<string>& variables = expression.getVariables();
for (int i = 0; i < force.getNumPerParticleParameters(); i++) { for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
if (variables.find(paramNames[2*i]) != variables.end()) if (variables.find(paramNames[2*i]) != variables.end()) {
expression.getVariableReference(paramNames[2*i]) = params1[i]; float* pointer = expression.getVariablePointer(paramNames[2*i]);
if (variables.find(paramNames[2*i+1]) != variables.end()) for (int j = 0; j < width; j++)
expression.getVariableReference(paramNames[2*i+1]) = params2[i]; pointer[j] = params1[i];
}
if (variables.find(paramNames[2*i+1]) != variables.end()) {
float* pointer = expression.getVariablePointer(paramNames[2*i+1]);
for (int j = 0; j < width; j++)
pointer[j] = params2[i];
}
} }
for (int i = 0; i < force.getNumComputedValues(); i++) { for (int i = 0; i < force.getNumComputedValues(); i++) {
if (variables.find(computedValueNames[2*i]) != variables.end()) if (variables.find(computedValueNames[2*i]) != variables.end()) {
expression.getVariableReference(computedValueNames[2*i]) = computedValues1[i]; float* pointer = expression.getVariablePointer(computedValueNames[2*i]);
if (variables.find(computedValueNames[2*i+1]) != variables.end()) for (int j = 0; j < width; j++)
expression.getVariableReference(computedValueNames[2*i+1]) = computedValues2[i]; pointer[j] = computedValues1[i];
}
if (variables.find(computedValueNames[2*i+1]) != variables.end()) {
float* pointer = expression.getVariablePointer(computedValueNames[2*i+1]);
for (int j = 0; j < width; j++)
pointer[j] = computedValues2[i];
}
} }
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i); const string& name = force.getGlobalParameterName(i);
if (variables.find(name) != variables.end()) if (variables.find(name) != variables.end()) {
expression.getVariableReference(name) = context.getParameter(name); float* pointer = expression.getVariablePointer(name);
for (int j = 0; j < width; j++)
pointer[j] = context.getParameter(name);
}
} }
// To integrate from r_cutoff to infinity, make the change of variables x=r_cutoff/r and integrate from 0 to 1. // To integrate from r_cutoff to infinity, make the change of variables x=r_cutoff/r and integrate from 0 to 1.
// This introduces another r^2 into the integral, which along with the r^2 in the formula for the correction // This introduces another r^2 into the integral, which along with the r^2 in the formula for the correction
// means we multiply the function by r^4. Use the midpoint method. // means we multiply the function by r^4. Use the midpoint method.
double* rPointer; float* r;
try { try {
rPointer = &expression.getVariableReference("r"); r = expression.getVariablePointer("r");
} }
catch (exception& ex) { catch (exception& ex) {
throw OpenMMException("CustomNonbondedForce: Cannot use long range correction with a force that does not depend on r."); throw OpenMMException("CustomNonbondedForce: Cannot use long range correction with a force that does not depend on r.");
...@@ -366,14 +389,20 @@ double CustomNonbondedForceImpl::integrateInteraction(Lepton::CompiledExpression ...@@ -366,14 +389,20 @@ double CustomNonbondedForceImpl::integrateInteraction(Lepton::CompiledExpression
for (int iteration = 0; ; iteration++) { for (int iteration = 0; ; iteration++) {
double oldSum = sum; double oldSum = sum;
double newSum = 0; double newSum = 0;
int element = 0;
for (int i = 0; i < numPoints; i++) { for (int i = 0; i < numPoints; i++) {
if (i%3 == 1) if (i%3 != 1) {
continue;
double x = (i+0.5)/numPoints; double x = (i+0.5)/numPoints;
double r = cutoff/x; r[element++] = cutoff/x;
*rPointer = r; if (element == width || i == numPoints-1) {
double r2 = r*r; const float* result = expression.evaluate();
newSum += expression.evaluate()*r2*r2; for (int j = 0; j < element; j++) {
float r2 = r[j]*r[j];
newSum += result[j]*r2*r2;
}
element = 0;
}
}
} }
sum = newSum/numPoints + oldSum/3; sum = newSum/numPoints + oldSum/3;
double relativeChange = fabs((sum-oldSum)/sum); double relativeChange = fabs((sum-oldSum)/sum);
...@@ -391,17 +420,23 @@ double CustomNonbondedForceImpl::integrateInteraction(Lepton::CompiledExpression ...@@ -391,17 +420,23 @@ double CustomNonbondedForceImpl::integrateInteraction(Lepton::CompiledExpression
double rswitch = force.getSwitchingDistance(); double rswitch = force.getSwitchingDistance();
sum2 = 0; sum2 = 0;
numPoints = 1; numPoints = 1;
vector<double> switchValue(width);
for (int iteration = 0; ; iteration++) { for (int iteration = 0; ; iteration++) {
double oldSum = sum2; double oldSum = sum2;
double newSum = 0; double newSum = 0;
int element = 0;
for (int i = 0; i < numPoints; i++) { for (int i = 0; i < numPoints; i++) {
if (i%3 == 1) if (i%3 != 1) {
continue;
double x = (i+0.5)/numPoints; double x = (i+0.5)/numPoints;
double r = rswitch+x*(cutoff-rswitch); switchValue[element] = x*x*x*(10+x*(-15+x*6));
double switchValue = x*x*x*(10+x*(-15+x*6)); r[element++] = rswitch+x*(cutoff-rswitch);
*rPointer = r; if (element == width || i == numPoints-1) {
newSum += switchValue*expression.evaluate()*r*r; const float* result = expression.evaluate();
for (int j = 0; j < element; j++)
newSum += switchValue[j]*result[j]*r[j]*r[j];
element = 0;
}
}
} }
sum2 = newSum/numPoints + oldSum/3; sum2 = newSum/numPoints + oldSum/3;
double relativeChange = fabs((sum2-oldSum)/sum2); double relativeChange = fabs((sum2-oldSum)/sum2);
......
...@@ -2073,7 +2073,7 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -2073,7 +2073,7 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection() && cc.getContextIndex() == 0) { if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection() && cc.getContextIndex() == 0) {
forceCopy = new CustomNonbondedForce(force); forceCopy = new CustomNonbondedForce(force);
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force); longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, cc.getThreadPool().getNumThreads());
cc.addPostComputation(new LongRangePostComputation(cc, longRangeCoefficient, longRangeCoefficientDerivs, forceCopy)); cc.addPostComputation(new LongRangePostComputation(cc, longRangeCoefficient, longRangeCoefficientDerivs, forceCopy));
hasInitializedLongRangeCorrection = false; hasInitializedLongRangeCorrection = false;
} }
...@@ -2449,7 +2449,7 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& ...@@ -2449,7 +2449,7 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl&
// If necessary, recompute the long range correction. // If necessary, recompute the long range correction.
if (forceCopy != NULL) { if (forceCopy != NULL) {
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force); longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, cc.getThreadPool().getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, cc.getThreadPool()); CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, cc.getThreadPool());
hasInitializedLongRangeCorrection = false; hasInitializedLongRangeCorrection = false;
*forceCopy = force; *forceCopy = force;
......
...@@ -1011,7 +1011,7 @@ double CpuCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool inc ...@@ -1011,7 +1011,7 @@ double CpuCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool inc
// Add in the long range correction. // Add in the long range correction.
if (!hasInitializedLongRangeCorrection) { if (!hasInitializedLongRangeCorrection) {
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(*forceCopy); longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(*forceCopy, data.threads.getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, data.threads); CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, data.threads);
hasInitializedLongRangeCorrection = true; hasInitializedLongRangeCorrection = true;
} }
...@@ -1042,7 +1042,7 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con ...@@ -1042,7 +1042,7 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con
// If necessary, recompute the long range correction. // If necessary, recompute the long range correction.
if (forceCopy != NULL) { if (forceCopy != NULL) {
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force); longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, data.threads.getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, data.threads); CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, data.threads);
hasInitializedLongRangeCorrection = true; hasInitializedLongRangeCorrection = true;
*forceCopy = force; *forceCopy = force;
......
...@@ -1303,8 +1303,8 @@ double ReferenceCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bo ...@@ -1303,8 +1303,8 @@ double ReferenceCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bo
// Add in the long range correction. // Add in the long range correction.
if (!hasInitializedLongRangeCorrection) { if (!hasInitializedLongRangeCorrection) {
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(*forceCopy);
ThreadPool threads; ThreadPool threads;
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(*forceCopy, threads.getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, threads); CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, threads);
hasInitializedLongRangeCorrection = true; hasInitializedLongRangeCorrection = true;
} }
...@@ -1337,8 +1337,8 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp ...@@ -1337,8 +1337,8 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp
// If necessary, recompute the long range correction. // If necessary, recompute the long range correction.
if (forceCopy != NULL) { if (forceCopy != NULL) {
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force);
ThreadPool threads; ThreadPool threads;
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, threads.getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, threads); CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, threads);
hasInitializedLongRangeCorrection = true; hasInitializedLongRangeCorrection = true;
*forceCopy = force; *forceCopy = force;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment