Unverified Commit 26df7a87 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Cache coefficients for long range correction (#5239)

* Cache coefficients for long range correction

* updateParametersInContext() clears cache
parent 13c568d0
......@@ -7,7 +7,7 @@
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2008-2025 Stanford University and the Authors. *
* Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -90,6 +90,8 @@ private:
std::vector<std::string> paramNames, computedValueNames;
std::vector<ComputeParameterInfo> paramBuffers, computedValueBuffers;
double longRangeCoefficient;
std::map<std::vector<float>, double> longRangeCoefficientCache;
std::map<std::vector<float>, std::vector<double> > longRangeCoefficientDerivsCache;
std::vector<double> longRangeCoefficientDerivs;
bool hasInitializedLongRangeCorrection, hasInitializedKernel, hasParamDerivs, useNeighborList, needGlobalParams;
int numGroupThreadBlocks;
......
......@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2008-2025 Stanford University and the Authors. *
* Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -110,20 +110,29 @@ private:
class CommonCalcCustomNonbondedForceKernel::LongRangeTask : public ComputeContext::WorkTask {
public:
LongRangeTask(ComputeContext& cc, Context& context, CustomNonbondedForceImpl::LongRangeCorrectionData& data,
double& longRangeCoefficient, vector<double>& longRangeCoefficientDerivs, CustomNonbondedForce* force) :
cc(cc), context(context), data(data), longRangeCoefficient(longRangeCoefficient),
longRangeCoefficientDerivs(longRangeCoefficientDerivs), force(force) {
LongRangeTask(ComputeContext& cc, Context& context, CustomNonbondedForceImpl::LongRangeCorrectionData& data, vector<float>& globalParamValues,
double& longRangeCoefficient, vector<double>& longRangeCoefficientDerivs, CustomNonbondedForce* force,
map<vector<float>, double>& longRangeCoefficientCache, map<vector<float>, vector<double> >& longRangeCoefficientDerivsCache) :
cc(cc), context(context), data(data), globalParamValues(globalParamValues), longRangeCoefficient(longRangeCoefficient),
longRangeCoefficientDerivs(longRangeCoefficientDerivs), force(force), longRangeCoefficientCache(longRangeCoefficientCache),
longRangeCoefficientDerivsCache(longRangeCoefficientDerivsCache) {
}
void execute() {
CustomNonbondedForceImpl::calcLongRangeCorrection(*force, data, context, longRangeCoefficient, longRangeCoefficientDerivs, cc.getThreadPool());
if (longRangeCoefficientCache.size() < 1000) {
longRangeCoefficientCache[globalParamValues] = longRangeCoefficient;
longRangeCoefficientDerivsCache[globalParamValues] = longRangeCoefficientDerivs;
}
}
private:
ComputeContext& cc;
Context& context;
CustomNonbondedForceImpl::LongRangeCorrectionData& data;
vector<float>& globalParamValues;
double& longRangeCoefficient;
vector<double>& longRangeCoefficientDerivs;
map<vector<float>, double>& longRangeCoefficientCache;
map<vector<float>, vector<double> >& longRangeCoefficientDerivsCache;
CustomNonbondedForce* force;
};
......@@ -639,9 +648,15 @@ double CommonCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool
globalParamValues[i] = value;
}
}
if (recomputeLongRangeCorrection && longRangeCoefficientCache.find(globalParamValues) != longRangeCoefficientCache.end()) {
longRangeCoefficient = longRangeCoefficientCache[globalParamValues];
longRangeCoefficientDerivs = longRangeCoefficientDerivsCache[globalParamValues];
recomputeLongRangeCorrection = false;
}
if (recomputeLongRangeCorrection) {
if (includeEnergy || forceCopy->getNumEnergyParameterDerivatives() > 0) {
cc.getWorkThread().addTask(new LongRangeTask(cc, context.getOwner(), longRangeCorrectionData, longRangeCoefficient, longRangeCoefficientDerivs, forceCopy));
cc.getWorkThread().addTask(new LongRangeTask(cc, context.getOwner(), longRangeCorrectionData, globalParamValues, longRangeCoefficient,
longRangeCoefficientDerivs, forceCopy, longRangeCoefficientCache, longRangeCoefficientDerivsCache));
hasInitializedLongRangeCorrection = true;
}
else
......@@ -723,15 +738,6 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl&
params->setParameterValuesSubset(firstParticle, paramVector);
}
// If necessary, recompute the long range correction.
if (forceCopy != NULL) {
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, cc.getThreadPool().getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, cc.getThreadPool());
hasInitializedLongRangeCorrection = false;
*forceCopy = force;
}
// See if any tabulated functions have changed.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
......@@ -744,6 +750,16 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl&
}
}
// If necessary, recompute the long range correction.
if (forceCopy != NULL) {
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, cc.getThreadPool().getNumThreads());
hasInitializedLongRangeCorrection = false;
*forceCopy = force;
longRangeCoefficientCache.clear();
longRangeCoefficientDerivsCache.clear();
}
// Mark that the current reordering may be invalid.
cc.invalidateMolecules(info, firstParticle <= lastParticle, false);
......
......@@ -1498,7 +1498,7 @@ void testComputedValues(int mode) {
Context context(system, integrator, platform);
context.setPositions(positions);
for (double lambda : {0.0, 0.3, 0.7, 1.0}) {
for (double lambda : {0.0, 0.3, 0.7, 1.0, 0.3}) { // Testing 0.3 twice checks caching of long range correction coefficient
context.setParameter("lambda", lambda);
double e1 = context.getState(State::Energy, false, 1<<0).getPotentialEnergy();
double e2 = context.getState(State::Energy, false, 1<<1).getPotentialEnergy();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment