Unverified Commit 26df7a87 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Cache coefficients for long range correction (#5239)

* Cache coefficients for long range correction

* updateParametersInContext() clears cache
parent 13c568d0
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2025 Stanford University and the Authors. * * Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -90,6 +90,8 @@ private: ...@@ -90,6 +90,8 @@ private:
std::vector<std::string> paramNames, computedValueNames; std::vector<std::string> paramNames, computedValueNames;
std::vector<ComputeParameterInfo> paramBuffers, computedValueBuffers; std::vector<ComputeParameterInfo> paramBuffers, computedValueBuffers;
double longRangeCoefficient; double longRangeCoefficient;
std::map<std::vector<float>, double> longRangeCoefficientCache;
std::map<std::vector<float>, std::vector<double> > longRangeCoefficientDerivsCache;
std::vector<double> longRangeCoefficientDerivs; std::vector<double> longRangeCoefficientDerivs;
bool hasInitializedLongRangeCorrection, hasInitializedKernel, hasParamDerivs, useNeighborList, needGlobalParams; bool hasInitializedLongRangeCorrection, hasInitializedKernel, hasParamDerivs, useNeighborList, needGlobalParams;
int numGroupThreadBlocks; int numGroupThreadBlocks;
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2025 Stanford University and the Authors. * * Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -110,20 +110,29 @@ private: ...@@ -110,20 +110,29 @@ private:
class CommonCalcCustomNonbondedForceKernel::LongRangeTask : public ComputeContext::WorkTask { class CommonCalcCustomNonbondedForceKernel::LongRangeTask : public ComputeContext::WorkTask {
public: public:
LongRangeTask(ComputeContext& cc, Context& context, CustomNonbondedForceImpl::LongRangeCorrectionData& data, LongRangeTask(ComputeContext& cc, Context& context, CustomNonbondedForceImpl::LongRangeCorrectionData& data, vector<float>& globalParamValues,
double& longRangeCoefficient, vector<double>& longRangeCoefficientDerivs, CustomNonbondedForce* force) : double& longRangeCoefficient, vector<double>& longRangeCoefficientDerivs, CustomNonbondedForce* force,
cc(cc), context(context), data(data), longRangeCoefficient(longRangeCoefficient), map<vector<float>, double>& longRangeCoefficientCache, map<vector<float>, vector<double> >& longRangeCoefficientDerivsCache) :
longRangeCoefficientDerivs(longRangeCoefficientDerivs), force(force) { cc(cc), context(context), data(data), globalParamValues(globalParamValues), longRangeCoefficient(longRangeCoefficient),
longRangeCoefficientDerivs(longRangeCoefficientDerivs), force(force), longRangeCoefficientCache(longRangeCoefficientCache),
longRangeCoefficientDerivsCache(longRangeCoefficientDerivsCache) {
} }
void execute() { void execute() {
CustomNonbondedForceImpl::calcLongRangeCorrection(*force, data, context, longRangeCoefficient, longRangeCoefficientDerivs, cc.getThreadPool()); CustomNonbondedForceImpl::calcLongRangeCorrection(*force, data, context, longRangeCoefficient, longRangeCoefficientDerivs, cc.getThreadPool());
if (longRangeCoefficientCache.size() < 1000) {
longRangeCoefficientCache[globalParamValues] = longRangeCoefficient;
longRangeCoefficientDerivsCache[globalParamValues] = longRangeCoefficientDerivs;
}
} }
private: private:
ComputeContext& cc; ComputeContext& cc;
Context& context; Context& context;
CustomNonbondedForceImpl::LongRangeCorrectionData& data; CustomNonbondedForceImpl::LongRangeCorrectionData& data;
vector<float>& globalParamValues;
double& longRangeCoefficient; double& longRangeCoefficient;
vector<double>& longRangeCoefficientDerivs; vector<double>& longRangeCoefficientDerivs;
map<vector<float>, double>& longRangeCoefficientCache;
map<vector<float>, vector<double> >& longRangeCoefficientDerivsCache;
CustomNonbondedForce* force; CustomNonbondedForce* force;
}; };
...@@ -639,9 +648,15 @@ double CommonCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool ...@@ -639,9 +648,15 @@ double CommonCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool
globalParamValues[i] = value; globalParamValues[i] = value;
} }
} }
if (recomputeLongRangeCorrection && longRangeCoefficientCache.find(globalParamValues) != longRangeCoefficientCache.end()) {
longRangeCoefficient = longRangeCoefficientCache[globalParamValues];
longRangeCoefficientDerivs = longRangeCoefficientDerivsCache[globalParamValues];
recomputeLongRangeCorrection = false;
}
if (recomputeLongRangeCorrection) { if (recomputeLongRangeCorrection) {
if (includeEnergy || forceCopy->getNumEnergyParameterDerivatives() > 0) { if (includeEnergy || forceCopy->getNumEnergyParameterDerivatives() > 0) {
cc.getWorkThread().addTask(new LongRangeTask(cc, context.getOwner(), longRangeCorrectionData, longRangeCoefficient, longRangeCoefficientDerivs, forceCopy)); cc.getWorkThread().addTask(new LongRangeTask(cc, context.getOwner(), longRangeCorrectionData, globalParamValues, longRangeCoefficient,
longRangeCoefficientDerivs, forceCopy, longRangeCoefficientCache, longRangeCoefficientDerivsCache));
hasInitializedLongRangeCorrection = true; hasInitializedLongRangeCorrection = true;
} }
else else
...@@ -723,15 +738,6 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& ...@@ -723,15 +738,6 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl&
params->setParameterValuesSubset(firstParticle, paramVector); params->setParameterValuesSubset(firstParticle, paramVector);
} }
// If necessary, recompute the long range correction.
if (forceCopy != NULL) {
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, cc.getThreadPool().getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, cc.getThreadPool());
hasInitializedLongRangeCorrection = false;
*forceCopy = force;
}
// See if any tabulated functions have changed. // See if any tabulated functions have changed.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
...@@ -744,6 +750,16 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& ...@@ -744,6 +750,16 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl&
} }
} }
// If necessary, recompute the long range correction.
if (forceCopy != NULL) {
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, cc.getThreadPool().getNumThreads());
hasInitializedLongRangeCorrection = false;
*forceCopy = force;
longRangeCoefficientCache.clear();
longRangeCoefficientDerivsCache.clear();
}
// Mark that the current reordering may be invalid. // Mark that the current reordering may be invalid.
cc.invalidateMolecules(info, firstParticle <= lastParticle, false); cc.invalidateMolecules(info, firstParticle <= lastParticle, false);
......
...@@ -1498,7 +1498,7 @@ void testComputedValues(int mode) { ...@@ -1498,7 +1498,7 @@ void testComputedValues(int mode) {
Context context(system, integrator, platform); Context context(system, integrator, platform);
context.setPositions(positions); context.setPositions(positions);
for (double lambda : {0.0, 0.3, 0.7, 1.0}) { for (double lambda : {0.0, 0.3, 0.7, 1.0, 0.3}) { // Testing 0.3 twice checks caching of long range correction coefficient
context.setParameter("lambda", lambda); context.setParameter("lambda", lambda);
double e1 = context.getState(State::Energy, false, 1<<0).getPotentialEnergy(); double e1 = context.getState(State::Energy, false, 1<<0).getPotentialEnergy();
double e2 = context.getState(State::Energy, false, 1<<1).getPotentialEnergy(); double e2 = context.getState(State::Energy, false, 1<<1).getPotentialEnergy();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment