Unverified Commit c3c8ec55 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Vectorized calculating long range correction coefficient (#3606)

parent b096cd7c
......@@ -37,6 +37,7 @@
#include "openmm/Kernel.h"
#include "openmm/internal/ThreadPool.h"
#include "lepton/CompiledExpression.h"
#include "lepton/CompiledVectorExpression.h"
#include <utility>
#include <map>
#include <string>
......@@ -69,7 +70,7 @@ public:
* the Context (such as global parameters). This allows the coefficient to be updated
* more quickly when global parameters change.
*/
static LongRangeCorrectionData prepareLongRangeCorrection(const CustomNonbondedForce& force);
static LongRangeCorrectionData prepareLongRangeCorrection(const CustomNonbondedForce& force, int numThreads);
/**
* Compute the coefficient which, when divided by the periodic box volume, gives the
* long range correction to the energy. If the Force computes parameter derivatives,
......@@ -77,7 +78,7 @@ public:
*/
static void calcLongRangeCorrection(const CustomNonbondedForce& force, LongRangeCorrectionData& data, const Context& context, double& coefficient, std::vector<double>& derivatives, ThreadPool& threads);
private:
static double integrateInteraction(Lepton::CompiledExpression& expression, const std::vector<double>& params1, const std::vector<double>& params2,
static double integrateInteraction(Lepton::CompiledVectorExpression& expression, const std::vector<double>& params1, const std::vector<double>& params2,
const std::vector<double>& computedValues1, const std::vector<double>& computedValues2, const CustomNonbondedForce& force, const Context& context,
const std::vector<std::string>& paramNames, const std::vector<std::string>& computedValueNames);
const CustomNonbondedForce& owner;
......@@ -90,8 +91,9 @@ public:
std::vector<std::vector<double> > classes;
std::vector<std::string> paramNames, computedValueNames;
std::map<std::pair<int, int>, long long int> interactionCount;
Lepton::CompiledExpression energyExpression;
std::vector<Lepton::CompiledExpression> derivExpressions, computedValueExpressions;
std::vector<Lepton::CompiledVectorExpression> energyExpression;
std::vector<std::vector<Lepton::CompiledVectorExpression> > derivExpressions;
std::vector<Lepton::CompiledExpression> computedValueExpressions;
};
} // namespace OpenMM
......
......@@ -162,7 +162,7 @@ void CustomNonbondedForceImpl::updateParametersInContext(ContextImpl& context) {
context.systemChanged();
}
CustomNonbondedForceImpl::LongRangeCorrectionData CustomNonbondedForceImpl::prepareLongRangeCorrection(const CustomNonbondedForce& force) {
CustomNonbondedForceImpl::LongRangeCorrectionData CustomNonbondedForceImpl::prepareLongRangeCorrection(const CustomNonbondedForce& force, int numThreads) {
LongRangeCorrectionData data;
data.method = force.getNonbondedMethod();
if (data.method == CustomNonbondedForce::NoCutoff || data.method == CustomNonbondedForce::CutoffNonPeriodic)
......@@ -227,12 +227,19 @@ CustomNonbondedForceImpl::LongRangeCorrectionData CustomNonbondedForceImpl::prep
// Prepare for evaluating the expressions.
int width = Lepton::CompiledVectorExpression::getAllowedWidths().back();
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
data.energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), functions).createCompiledExpression();
for (int k = 0; k < force.getNumEnergyParameterDerivatives(); k++)
data.derivExpressions.push_back(Lepton::Parser::parse(force.getEnergyFunction(), functions).differentiate(force.getEnergyParameterDerivativeName(k)).createCompiledExpression());
Lepton::CompiledVectorExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), functions).createCompiledVectorExpression(width);
for (int i = 0; i < numThreads; i++)
data.energyExpression.push_back(energyExpression);
data.derivExpressions.resize(numThreads);
for (int k = 0; k < force.getNumEnergyParameterDerivatives(); k++) {
Lepton::CompiledVectorExpression derivExpression = Lepton::Parser::parse(force.getEnergyFunction(), functions).differentiate(force.getEnergyParameterDerivativeName(k)).createCompiledVectorExpression(width);
for (int i = 0; i < numThreads; i++)
data.derivExpressions[i].push_back(derivExpression);
}
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
string name = force.getPerParticleParameterName(i);
data.paramNames.push_back(name+"1");
......@@ -283,7 +290,7 @@ void CustomNonbondedForceImpl::calcLongRangeCorrection(const CustomNonbondedForc
vector<double> threadSum(threads.getNumThreads(), 0.0);
atomic<int> atomicCounter(0);
threads.execute([&] (ThreadPool& threads, int threadIndex) {
Lepton::CompiledExpression expression = data.energyExpression;
Lepton::CompiledVectorExpression& expression = data.energyExpression[threadIndex];
while (true) {
int i = atomicCounter++;
if (i >= numClasses)
......@@ -302,13 +309,13 @@ void CustomNonbondedForceImpl::calcLongRangeCorrection(const CustomNonbondedForc
// Now do the same for parameter derivatives.
int numDerivs = data.derivExpressions.size();
int numDerivs = data.derivExpressions[0].size();
derivatives.resize(numDerivs);
for (int k = 0; k < numDerivs; k++) {
atomicCounter = 0;
threads.execute([&] (ThreadPool& threads, int threadIndex) {
threadSum[threadIndex] = 0;
Lepton::CompiledExpression expression = data.derivExpressions[k];
Lepton::CompiledVectorExpression& expression = data.derivExpressions[threadIndex][k];
while (true) {
int i = atomicCounter++;
if (i >= numClasses)
......@@ -327,35 +334,51 @@ void CustomNonbondedForceImpl::calcLongRangeCorrection(const CustomNonbondedForc
}
}
double CustomNonbondedForceImpl::integrateInteraction(Lepton::CompiledExpression& expression, const vector<double>& params1, const vector<double>& params2,
double CustomNonbondedForceImpl::integrateInteraction(Lepton::CompiledVectorExpression& expression, const vector<double>& params1, const vector<double>& params2,
const vector<double>& computedValues1, const vector<double>& computedValues2, const CustomNonbondedForce& force, const Context& context,
const vector<string>& paramNames, const vector<string>& computedValueNames) {
int width = expression.getWidth();
const set<string>& variables = expression.getVariables();
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
if (variables.find(paramNames[2*i]) != variables.end())
expression.getVariableReference(paramNames[2*i]) = params1[i];
if (variables.find(paramNames[2*i+1]) != variables.end())
expression.getVariableReference(paramNames[2*i+1]) = params2[i];
if (variables.find(paramNames[2*i]) != variables.end()) {
float* pointer = expression.getVariablePointer(paramNames[2*i]);
for (int j = 0; j < width; j++)
pointer[j] = params1[i];
}
if (variables.find(paramNames[2*i+1]) != variables.end()) {
float* pointer = expression.getVariablePointer(paramNames[2*i+1]);
for (int j = 0; j < width; j++)
pointer[j] = params2[i];
}
}
for (int i = 0; i < force.getNumComputedValues(); i++) {
if (variables.find(computedValueNames[2*i]) != variables.end())
expression.getVariableReference(computedValueNames[2*i]) = computedValues1[i];
if (variables.find(computedValueNames[2*i+1]) != variables.end())
expression.getVariableReference(computedValueNames[2*i+1]) = computedValues2[i];
if (variables.find(computedValueNames[2*i]) != variables.end()) {
float* pointer = expression.getVariablePointer(computedValueNames[2*i]);
for (int j = 0; j < width; j++)
pointer[j] = computedValues1[i];
}
if (variables.find(computedValueNames[2*i+1]) != variables.end()) {
float* pointer = expression.getVariablePointer(computedValueNames[2*i+1]);
for (int j = 0; j < width; j++)
pointer[j] = computedValues2[i];
}
}
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i);
if (variables.find(name) != variables.end())
expression.getVariableReference(name) = context.getParameter(name);
if (variables.find(name) != variables.end()) {
float* pointer = expression.getVariablePointer(name);
for (int j = 0; j < width; j++)
pointer[j] = context.getParameter(name);
}
}
// To integrate from r_cutoff to infinity, make the change of variables x=r_cutoff/r and integrate from 0 to 1.
// This introduces another r^2 into the integral, which along with the r^2 in the formula for the correction
// means we multiply the function by r^4. Use the midpoint method.
double* rPointer;
float* r;
try {
rPointer = &expression.getVariableReference("r");
r = expression.getVariablePointer("r");
}
catch (exception& ex) {
throw OpenMMException("CustomNonbondedForce: Cannot use long range correction with a force that does not depend on r.");
......@@ -366,14 +389,20 @@ double CustomNonbondedForceImpl::integrateInteraction(Lepton::CompiledExpression
for (int iteration = 0; ; iteration++) {
double oldSum = sum;
double newSum = 0;
int element = 0;
for (int i = 0; i < numPoints; i++) {
if (i%3 == 1)
continue;
double x = (i+0.5)/numPoints;
double r = cutoff/x;
*rPointer = r;
double r2 = r*r;
newSum += expression.evaluate()*r2*r2;
if (i%3 != 1) {
double x = (i+0.5)/numPoints;
r[element++] = cutoff/x;
if (element == width || i == numPoints-1) {
const float* result = expression.evaluate();
for (int j = 0; j < element; j++) {
float r2 = r[j]*r[j];
newSum += result[j]*r2*r2;
}
element = 0;
}
}
}
sum = newSum/numPoints + oldSum/3;
double relativeChange = fabs((sum-oldSum)/sum);
......@@ -383,25 +412,31 @@ double CustomNonbondedForceImpl::integrateInteraction(Lepton::CompiledExpression
throw OpenMMException("CustomNonbondedForce: Long range correction did not converge. Does the energy go to 0 faster than 1/r^2?");
numPoints *= 3;
}
// If a switching function is used, integrate over the switching interval.
double sum2 = 0;
if (force.getUseSwitchingFunction()) {
double rswitch = force.getSwitchingDistance();
sum2 = 0;
numPoints = 1;
vector<double> switchValue(width);
for (int iteration = 0; ; iteration++) {
double oldSum = sum2;
double newSum = 0;
int element = 0;
for (int i = 0; i < numPoints; i++) {
if (i%3 == 1)
continue;
double x = (i+0.5)/numPoints;
double r = rswitch+x*(cutoff-rswitch);
double switchValue = x*x*x*(10+x*(-15+x*6));
*rPointer = r;
newSum += switchValue*expression.evaluate()*r*r;
if (i%3 != 1) {
double x = (i+0.5)/numPoints;
switchValue[element] = x*x*x*(10+x*(-15+x*6));
r[element++] = rswitch+x*(cutoff-rswitch);
if (element == width || i == numPoints-1) {
const float* result = expression.evaluate();
for (int j = 0; j < element; j++)
newSum += switchValue[j]*result[j]*r[j]*r[j];
element = 0;
}
}
}
sum2 = newSum/numPoints + oldSum/3;
double relativeChange = fabs((sum2-oldSum)/sum2);
......
......@@ -2073,7 +2073,7 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection() && cc.getContextIndex() == 0) {
forceCopy = new CustomNonbondedForce(force);
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force);
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, cc.getThreadPool().getNumThreads());
cc.addPostComputation(new LongRangePostComputation(cc, longRangeCoefficient, longRangeCoefficientDerivs, forceCopy));
hasInitializedLongRangeCorrection = false;
}
......@@ -2449,7 +2449,7 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl&
// If necessary, recompute the long range correction.
if (forceCopy != NULL) {
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force);
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, cc.getThreadPool().getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, cc.getThreadPool());
hasInitializedLongRangeCorrection = false;
*forceCopy = force;
......
......@@ -1011,7 +1011,7 @@ double CpuCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool inc
// Add in the long range correction.
if (!hasInitializedLongRangeCorrection) {
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(*forceCopy);
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(*forceCopy, data.threads.getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, data.threads);
hasInitializedLongRangeCorrection = true;
}
......@@ -1042,7 +1042,7 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con
// If necessary, recompute the long range correction.
if (forceCopy != NULL) {
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force);
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, data.threads.getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, data.threads);
hasInitializedLongRangeCorrection = true;
*forceCopy = force;
......
......@@ -1303,8 +1303,8 @@ double ReferenceCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bo
// Add in the long range correction.
if (!hasInitializedLongRangeCorrection) {
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(*forceCopy);
ThreadPool threads;
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(*forceCopy, threads.getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, threads);
hasInitializedLongRangeCorrection = true;
}
......@@ -1337,8 +1337,8 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp
// If necessary, recompute the long range correction.
if (forceCopy != NULL) {
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force);
ThreadPool threads;
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, threads.getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, threads);
hasInitializedLongRangeCorrection = true;
*forceCopy = force;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment