/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "openmm/common/CommonCalcCustomNonbondedForceKernel.h"
#include "openmm/common/CommonKernelUtilities.h"
#include "openmm/common/ContextSelector.h"
#include "openmm/common/ExpressionUtilities.h"
#include "openmm/Context.h"
#include "openmm/internal/ContextImpl.h"
#include "CommonKernelSources.h"
#include "lepton/CustomFunction.h"
#include "lepton/ExpressionTreeNode.h"
#include "lepton/Operation.h"
#include "lepton/Parser.h"
#include "lepton/ParsedExpression.h"
using namespace OpenMM;
using namespace std;
using namespace Lepton;
/**
 * ComputeForceInfo implementation for a CustomNonbondedForce.
 *
 * Two particles are reported as identical when every per-particle parameter matches
 * and, if the force uses interaction groups, they belong to exactly the same sets of
 * those groups.  Each exclusion is exposed as a particle group holding its two
 * particles, and all such groups are reported as identical to one another.
 */
class CommonCalcCustomNonbondedForceKernel::ForceInfo : public ComputeForceInfo {
public:
    ForceInfo(const CustomNonbondedForce& force) : force(force) {
        const int numInteractionGroups = force.getNumInteractionGroups();
        if (numInteractionGroups == 0)
            return;
        // Tag membership in the first set of interaction group g as 2*g, and
        // membership in its second set as 2*g+1.
        groupsForParticle.resize(force.getNumParticles());
        for (int g = 0; g < numInteractionGroups; g++) {
            set<int> firstSet, secondSet;
            force.getInteractionGroupParameters(g, firstSet, secondSet);
            for (int particle : firstSet)
                groupsForParticle[particle].insert(2*g);
            for (int particle : secondSet)
                groupsForParticle[particle].insert(2*g+1);
        }
    }
    bool areParticlesIdentical(int particle1, int particle2) {
        // Presumably thread_local static so the buffers are reused rather than reallocated per call.
        thread_local static vector<double> params1, params2;
        force.getParticleParameters(particle1, params1);
        force.getParticleParameters(particle2, params2);
        for (size_t k = 0; k < params1.size(); k++) {
            if (params1[k] != params2[k])
                return false;
        }
        // Without interaction groups there is no membership to compare.
        return groupsForParticle.empty() || groupsForParticle[particle1] == groupsForParticle[particle2];
    }
    int getNumParticleGroups() {
        // One particle group per exclusion.
        return force.getNumExclusions();
    }
    void getParticlesInGroup(int index, vector<int>& particles) {
        int first, second;
        force.getExclusionParticles(index, first, second);
        particles.assign({first, second});
    }
    bool areGroupsIdentical(int group1, int group2) {
        return true;
    }
private:
    const CustomNonbondedForce& force;
    // Interaction-group membership tags for each particle; empty when the force has no interaction groups.
    vector<set<int> > groupsForParticle;
};
/**
 * Post-computation that adds the long range correction to the energy.
 *
 * The correction energy is longRangeCoefficient divided by the periodic box volume.
 * The derivatives with respect to the requested energy parameters are scaled the same
 * way and accumulated into the context's energy-parameter-derivative workspace.  The
 * coefficients are held by reference, because a LongRangeTask may still be filling
 * them in on the work thread.
 */
class CommonCalcCustomNonbondedForceKernel::LongRangePostComputation : public ComputeContext::ForcePostComputation {
public:
    LongRangePostComputation(ComputeContext& cc, double& longRangeCoefficient, vector<double>& longRangeCoefficientDerivs, CustomNonbondedForce* force) :
            cc(cc), longRangeCoefficient(longRangeCoefficient), longRangeCoefficientDerivs(longRangeCoefficientDerivs), force(force) {
    }
    /**
     * Return the correction energy if this force's group is in `groups`, otherwise 0.
     */
    double computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
        if ((groups&(1<<force->getForceGroup())) == 0)
            return 0;
        // Wait for any pending LongRangeTask so the coefficients below are up to date.
        if (!cc.getWorkThread().isCurrentThread())
            cc.getWorkThread().flush();
        Vec3 a, b, c;
        cc.getPeriodicBoxVectors(a, b, c);
        // Assumes the box vectors are in OpenMM's reduced (triangular) form, where the
        // volume is the product of the diagonal elements.
        double volume = a[0]*b[1]*c[2];
        map<string, double>& derivs = cc.getEnergyParamDerivWorkspace();
        for (size_t i = 0; i < longRangeCoefficientDerivs.size(); i++)
            derivs[force->getEnergyParameterDerivativeName(i)] += longRangeCoefficientDerivs[i]/volume;
        return longRangeCoefficient/volume;
    }
private:
    ComputeContext& cc;
    double& longRangeCoefficient;
    vector<double>& longRangeCoefficientDerivs;
    CustomNonbondedForce* force;
};
/**
 * Work-thread task that recomputes the long range correction.
 *
 * Calls CustomNonbondedForceImpl::calcLongRangeCorrection to compute the coefficient and
 * its energy-parameter derivatives, then records both in the caches keyed by the current
 * global parameter values.  All state is held by reference, so results are written
 * directly into the kernel's fields.  globalParamValues is read when the task runs, not
 * when it is created.
 */
class CommonCalcCustomNonbondedForceKernel::LongRangeTask : public ComputeContext::WorkTask {
public:
    LongRangeTask(ComputeContext& cc, Context& context, CustomNonbondedForceImpl::LongRangeCorrectionData& data, vector<float>& globalParamValues,
            double& longRangeCoefficient, vector<double>& longRangeCoefficientDerivs, CustomNonbondedForce* force,
            map<vector<float>, double>& longRangeCoefficientCache, map<vector<float>, vector<double> >& longRangeCoefficientDerivsCache) :
            cc(cc), context(context), data(data), globalParamValues(globalParamValues), longRangeCoefficient(longRangeCoefficient),
            longRangeCoefficientDerivs(longRangeCoefficientDerivs), force(force), longRangeCoefficientCache(longRangeCoefficientCache),
            longRangeCoefficientDerivsCache(longRangeCoefficientDerivsCache) {
    }
    void execute() {
        CustomNonbondedForceImpl::calcLongRangeCorrection(*force, data, context, longRangeCoefficient, longRangeCoefficientDerivs, cc.getThreadPool());
        // Bound the cache size (presumably so scanning a global parameter over many values
        // cannot grow it without limit).  Once full, new results are simply not cached.
        const size_t maxCacheEntries = 1000;
        if (longRangeCoefficientCache.size() < maxCacheEntries) {
            longRangeCoefficientCache[globalParamValues] = longRangeCoefficient;
            longRangeCoefficientDerivsCache[globalParamValues] = longRangeCoefficientDerivs;
        }
    }
private:
    // Members are declared in the same order as the constructor's initializer list,
    // which is the order in which C++ actually initializes them.
    ComputeContext& cc;
    Context& context;
    CustomNonbondedForceImpl::LongRangeCorrectionData& data;
    vector<float>& globalParamValues;
    double& longRangeCoefficient;
    vector<double>& longRangeCoefficientDerivs;
    CustomNonbondedForce* force;
    map<vector<float>, double>& longRangeCoefficientCache;
    map<vector<float>, vector<double> >& longRangeCoefficientDerivsCache;
};
/**
 * Release the parameter sets and the force copy owned by this kernel.
 */
CommonCalcCustomNonbondedForceKernel::~CommonCalcCustomNonbondedForceKernel() {
    // Select this kernel's compute context while its device resources are released.
    ContextSelector selector(cc);
    // delete on a null pointer is a no-op, so members that were never allocated are safe here.
    delete params;
    delete computedValues;
    delete forceCopy;
}
/**
 * Set up device data and kernels for a CustomNonbondedForce.
 *
 * The steps are:
 * - Upload per-particle parameters and tabulated functions.
 * - Generate kernel code for the energy expression, its derivative with respect to r,
 *   and its derivatives with respect to any requested energy parameters.
 * - Either register the interaction with the shared nonbonded utilities or, when the
 *   force uses interaction groups, pass the source to initInteractionGroups().
 * - Compile a kernel for per-particle computed values, if there are any.
 * - Register a ForceInfo.
 * - Prepare the long range correction.  This only happens for CutoffPeriodic forces
 *   with the correction enabled, on context index 0.
 */
void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, const CustomNonbondedForce& force) {
ContextSelector selector(cc);
// Find this force's index in the System; it is used to build a unique name prefix.
int forceIndex;
for (forceIndex = 0; forceIndex < system.getNumForces() && &system.getForce(forceIndex) != &force; ++forceIndex)
;
// Names get a per-force prefix only when the code is added to the shared nonbonded kernel
// (presumably to avoid clashes with other forces); interaction groups compile their own program.
string prefix = (force.getNumInteractionGroups() == 0 ? "custom"+cc.intToString(forceIndex)+"_" : "");
// Record parameters and exclusions.
int numParticles = force.getNumParticles();
int paddedNumParticles = cc.getPaddedNumAtoms();
int numParams = force.getNumPerParticleParameters();
params = new ComputeParameterSet(cc, numParams, paddedNumParticles, "customNonbondedParameters", true);
// Rows for padding atoms (index >= numParticles) are left as zeros.
vector > paramVector(paddedNumParticles, vector(numParams, 0));
vector > exclusionList(numParticles);
for (int i = 0; i < numParticles; i++) {
vector parameters;
force.getParticleParameters(i, parameters);
paramVector[i].resize(parameters.size());
for (int j = 0; j < (int) parameters.size(); j++)
paramVector[i][j] = (float) parameters[j];
// Each particle is listed as excluded from itself.
exclusionList[i].push_back(i);
}
for (int i = 0; i < force.getNumExclusions(); i++) {
int particle1, particle2;
force.getExclusionParticles(i, particle1, particle2);
exclusionList[particle1].push_back(particle2);
exclusionList[particle2].push_back(particle1);
}
params->setParameterValues(paramVector);
// Record the tabulated functions.
map functions;
vector > functionDefinitions;
vector functionList;
vector tableTypes;
// Kernel parameter declarations for the tables, appended to the computed-values kernel signature below.
stringstream tableArgs;
tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
// Remember the update count so copyParametersToContext() can detect changed tables.
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
string arrayName = prefix+"table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].initialize(cc, f.size(), "TabulatedFunction");
tabulatedFunctionArrays[i].upload(f);
if (force.getNumInteractionGroups() == 0)
cc.getNonbondedUtilities().addArgument(ComputeParameterInfo(tabulatedFunctionArrays[i], arrayName, "float", width));
if (width == 1)
tableTypes.push_back("float");
else
tableTypes.push_back("float"+cc.intToString(width));
tableArgs << ", GLOBAL const float";
if (width > 1)
tableArgs << width;
tableArgs << "* RESTRICT " << arrayName;
}
// Record information for the expressions.
globalParamNames.resize(force.getNumGlobalParameters());
globalParamValues.resize(force.getNumGlobalParameters());
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
}
bool useCutoff = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff);
bool usePeriodic = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff && force.getNonbondedMethod() != CustomNonbondedForce::CutoffNonPeriodic);
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
Lepton::ParsedExpression forceExpression = energyExpression.differentiate("r").optimize();
// Each key is the left-hand side of a generated statement; the value is the expression it assigns.
map forceExpressions;
forceExpressions["real customEnergy = "] = energyExpression;
forceExpressions["tempForce -= "] = forceExpression;
// Record which per-particle parameters and computed values appear in the energy expression.
if (force.getNumComputedValues() > 0)
computedValues = new ComputeParameterSet(cc, force.getNumComputedValues(), paddedNumParticles, "customNonbondedComputedValues", true);
// Only names used with suffix 1 or 2 in the energy expression are passed to the pair kernel.
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
string name = force.getPerParticleParameterName(i);
if (usesVariable(energyExpression, name+"1") || usesVariable(energyExpression, name+"2")) {
paramNames.push_back(name);
paramBuffers.push_back(params->getParameterInfos()[i]);
}
}
for (int i = 0; i < force.getNumComputedValues(); i++) {
string name, expression;
force.getComputedValueParameters(i, name, expression);
if (usesVariable(energyExpression, name+"1") || usesVariable(energyExpression, name+"2")) {
computedValueNames.push_back(name);
computedValueBuffers.push_back(computedValues->getParameterInfos()[i]);
}
}
// Create the kernels.
// Map expression variables to the code that evaluates them: r, r^2, 1/r, the per-particle
// parameters and computed values of both particles, and the global parameters.
vector > variables;
ExpressionTreeNode rnode(new Operation::Variable("r"));
variables.push_back(make_pair(rnode, "r"));
variables.push_back(make_pair(ExpressionTreeNode(new Operation::Square(), rnode), "r2"));
variables.push_back(make_pair(ExpressionTreeNode(new Operation::Reciprocal(), rnode), "invR"));
for (int i = 0; i < paramNames.size(); i++) {
variables.push_back(makeVariable(paramNames[i]+"1", "((real) "+prefix+"params"+cc.intToString(i+1)+"1)"));
variables.push_back(makeVariable(paramNames[i]+"2", "((real) "+prefix+"params"+cc.intToString(i+1)+"2)"));
}
for (int i = 0; i < computedValueNames.size(); i++) {
variables.push_back(makeVariable(computedValueNames[i]+"1", prefix+"values"+cc.intToString(i+1)+"1"));
variables.push_back(makeVariable(computedValueNames[i]+"2", prefix+"values"+cc.intToString(i+1)+"2"));
}
needGlobalParams = (force.getNumGlobalParameters() > 0);
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i);
int index = cc.registerGlobalParam(name);
string value = "globals["+cc.intToString(index)+"]";
variables.push_back(makeVariable(name, prefix+value));
}
// Each requested energy-parameter derivative is accumulated into the variable returned by addEnergyParameterDerivative().
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
string paramName = force.getEnergyParameterDerivativeName(i);
string derivVariable = cc.getNonbondedUtilities().addEnergyParameterDerivative(paramName);
Lepton::ParsedExpression derivExpression = energyExpression.differentiate(paramName).optimize();
forceExpressions[derivVariable+" += interactionScale*switchValue*"] = derivExpression;
}
stringstream compute;
compute << cc.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, prefix+"temp");
map replacements;
replacements["COMPUTE_FORCE"] = compute.str();
// The switching function is only enabled when a cutoff is in use.
replacements["USE_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0");
if (force.getUseSwitchingFunction()) {
// Compute the switching coefficients.
// With rs = switching distance and rc = cutoff: C3 = 10/(rs-rc)^3, C4 = 15/(rs-rc)^4, C5 = 6/(rs-rc)^5.
// Presumably the usual 1-10x^3+15x^4-6x^5 switch written in powers of (r-rs); TODO confirm against the kernel source.
replacements["SWITCH_CUTOFF"] = cc.doubleToString(force.getSwitchingDistance());
replacements["SWITCH_C3"] = cc.doubleToString(10/pow(force.getSwitchingDistance()-force.getCutoffDistance(), 3.0));
replacements["SWITCH_C4"] = cc.doubleToString(15/pow(force.getSwitchingDistance()-force.getCutoffDistance(), 4.0));
replacements["SWITCH_C5"] = cc.doubleToString(6/pow(force.getSwitchingDistance()-force.getCutoffDistance(), 5.0));
}
string source = cc.replaceStrings(CommonKernelSources::customNonbonded, replacements);
if (force.getNumInteractionGroups() > 0)
initInteractionGroups(force, source, tableTypes);
else {
// Add the pair interaction to the shared nonbonded kernel.  The final argument is true only
// for systems with more than 2000 particles (see NonbondedUtilities::addInteraction for its meaning).
cc.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, true, force.getCutoffDistance(), exclusionList, source, force.getForceGroup(), numParticles > 2000);
for (int i = 0; i < paramBuffers.size(); i++)
cc.getNonbondedUtilities().addParameter(ComputeParameterInfo(paramBuffers[i].getArray(), prefix+"params"+cc.intToString(i+1),
paramBuffers[i].getComponentType(), paramBuffers[i].getNumComponents()));
for (int i = 0; i < computedValueBuffers.size(); i++)
cc.getNonbondedUtilities().addParameter(ComputeParameterInfo(computedValueBuffers[i].getArray(), prefix+"values"+cc.intToString(i+1),
computedValueBuffers[i].getComponentType(), computedValueBuffers[i].getNumComponents()));
if (needGlobalParams)
cc.getNonbondedUtilities().addArgument(ComputeParameterInfo(cc.getGlobalParamValues(), prefix+"globals", "real", 1));
}
if (force.getNumComputedValues() > 0) {
// Create the kernel to calculate computed values.
// Kernel arguments, in order: one output buffer per computed value, globals (only if the
// force has global parameters), one buffer per per-particle parameter, then the tables.
stringstream valuesSource, args;
for (int i = 0; i < computedValues->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = computedValues->getParameterInfos()[i];
string valueName = "values"+cc.intToString(i+1);
if (i > 0)
args << ", ";
args << "GLOBAL " << buffer.getType() << "* RESTRICT global_" << valueName;
valuesSource << buffer.getType() << " local_" << valueName << ";\n";
}
if (force.getNumGlobalParameters() > 0)
args << ", GLOBAL const real* globals";
for (int i = 0; i < params->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = params->getParameterInfos()[i];
string paramName = "params"+cc.intToString(i+1);
args << ", GLOBAL const " << buffer.getType() << "* RESTRICT " << paramName;
}
map variables;
for (int i = 0; i < force.getNumPerParticleParameters(); i++)
variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]");
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i);
int index = cc.registerGlobalParam(name);
variables[name] = "globals["+cc.intToString(index)+"]";
}
// Computed values are evaluated in order; each is added to the variable map as it is defined.
for (int i = 0; i < force.getNumComputedValues(); i++) {
string name, expression;
force.getComputedValueParameters(i, name, expression);
variables[name] = "local_values"+computedValues->getParameterSuffix(i);
map valueExpressions;
valueExpressions["local_values"+computedValues->getParameterSuffix(i)+" = "] = Lepton::Parser::parse(expression, functions).optimize();
valuesSource << cc.getExpressionUtilities().createExpressions(valueExpressions, variables, functionList, functionDefinitions, "value"+cc.intToString(i)+"_temp");
}
for (int i = 0; i < (int) computedValues->getParameterInfos().size(); i++) {
string valueName = "values"+cc.intToString(i+1);
valuesSource << "global_" << valueName << "[index] = local_" << valueName << ";\n";
}
map replacements;
replacements["PARAMETER_ARGUMENTS"] = args.str()+tableArgs.str();
replacements["COMPUTE_VALUES"] = valuesSource.str();
map defines;
defines["NUM_ATOMS"] = cc.intToString(cc.getNumAtoms());
ComputeProgram program = cc.compileProgram(cc.replaceStrings(CommonKernelSources::customNonbondedComputedValues, replacements), defines);
computedValuesKernel = program->createKernel("computePerParticleValues");
for (auto& value : computedValues->getParameterInfos())
computedValuesKernel->addArg(value.getArray());
// Placeholder for the global parameter array; the actual array is bound in execute().
if (needGlobalParams)
computedValuesKernel->addArg();
for (auto& parameter : params->getParameterInfos())
computedValuesKernel->addArg(parameter.getArray());
for (auto& function : tabulatedFunctionArrays)
computedValuesKernel->addArg(function);
}
info = new ForceInfo(force);
cc.addForce(info);
// Record information for the long range correction.
// Only context index 0 handles the correction (presumably so it is counted once across devices).
if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection() && cc.getContextIndex() == 0) {
forceCopy = new CustomNonbondedForce(force);
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, cc.getThreadPool().getNumThreads());
cc.addPostComputation(new LongRangePostComputation(cc, longRangeCoefficient, longRangeCoefficientDerivs, forceCopy));
// The coefficient itself is computed lazily in execute().
hasInitializedLongRangeCorrection = false;
}
else {
longRangeCoefficient = 0.0;
hasInitializedLongRangeCorrection = true;
}
}
/**
 * Build the tile data and kernels used when the force defines interaction groups.
 *
 * The steps are:
 * - Split each group's two atom sets into blocks of at most 32 atoms.  Every
 *   (block1, block2) pair becomes a tile.
 * - Compute per-tile exclusion flags.  Pairs are also excluded when both atoms appear in
 *   both sets of a group, so that duplicate interactions are skipped.
 * - Drop fully excluded tiles and pack the rest into tile sets of at most 32 atoms along
 *   the first axis.
 * - Upload the result as interactionGroupData.
 * - When a cutoff is used and there are more entries than thread blocks, allocate a
 *   filtered copy to serve as a neighbor list.
 * - Compile the customNonbondedGroups program.
 *
 * @param force              the force being initialized
 * @param interactionSource  generated code that computes one pair interaction
 * @param tableTypes         device type of each tabulated function array
 */
void CommonCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNonbondedForce& force, const string& interactionSource, const vector& tableTypes) {
// Process groups to form tiles.
vector > atomLists;
vector > tiles;
vector tileGroup;
vector > duplicateAtomsForGroup;
for (int group = 0; group < force.getNumInteractionGroups(); group++) {
// Get the list of atoms in this group and sort them.
set set1, set2;
force.getInteractionGroupParameters(group, set1, set2);
vector atoms1, atoms2;
atoms1.insert(atoms1.begin(), set1.begin(), set1.end());
atoms2.insert(atoms2.begin(), set2.begin(), set2.end());
sort(atoms1.begin(), atoms1.end());
sort(atoms2.begin(), atoms2.end());
// Atoms present in both sets; used later to avoid computing the same pair twice.
duplicateAtomsForGroup.push_back(vector());
set_intersection(set1.begin(), set1.end(), set2.begin(), set2.end(),
inserter(duplicateAtomsForGroup[group], duplicateAtomsForGroup[group].begin()));
sort(duplicateAtomsForGroup[group].begin(), duplicateAtomsForGroup[group].end());
// Find how many tiles we will create for this group.
// Block width is at most 32 and never larger than the smaller set; an empty set yields no tiles.
int tileWidth = min(min(32, (int) atoms1.size()), (int) atoms2.size());
if (tileWidth == 0)
continue;
int numBlocks1 = (atoms1.size()+tileWidth-1)/tileWidth;
int numBlocks2 = (atoms2.size()+tileWidth-1)/tileWidth;
// Add the tiles.
// Each tile refers to two entries of atomLists: a block of set1 and a block of set2.
int firstTile = tiles.size();
for (int i = 0; i < numBlocks1; i++)
for (int j = 0; j < numBlocks2; j++) {
tiles.push_back(make_pair(atomLists.size()+i, atomLists.size()+numBlocks1+j));
tileGroup.push_back(group);
}
// Add the atom lists.
for (int i = 0; i < numBlocks1; i++) {
vector atoms;
int first = i*tileWidth;
int last = min((i+1)*tileWidth, (int) atoms1.size());
for (int j = first; j < last; j++)
atoms.push_back(atoms1[j]);
atomLists.push_back(atoms);
}
for (int i = 0; i < numBlocks2; i++) {
vector atoms;
int first = i*tileWidth;
int last = min((i+1)*tileWidth, (int) atoms2.size());
for (int j = first; j < last; j++)
atoms.push_back(atoms2[j]);
atomLists.push_back(atoms);
}
}
// Build a lookup table for quickly identifying excluded interactions.
vector > exclusions(force.getNumParticles());
for (int i = 0; i < force.getNumExclusions(); i++) {
int p1, p2;
force.getExclusionParticles(i, p1, p2);
exclusions[p1].insert(p2);
exclusions[p2].insert(p1);
}
// Build the exclusion flags for each tile. While we're at it, filter out tiles
// where all interactions are excluded, and sort the tiles by size.
vector > exclusionFlags(tiles.size());
vector > tileOrder;
for (int tile = 0; tile < tiles.size(); tile++) {
bool swapped = false;
if (atomLists[tiles[tile].first].size() < atomLists[tiles[tile].second].size()) {
// For efficiency, we want the first axis to be the larger one.
int swap = tiles[tile].first;
tiles[tile].first = tiles[tile].second;
tiles[tile].second = swap;
swapped = true;
}
vector& atoms1 = atomLists[tiles[tile].first];
vector& atoms2 = atomLists[tiles[tile].second];
vector& duplicateAtoms = duplicateAtomsForGroup[tileGroup[tile]];
vector& flags = exclusionFlags[tile];
// NOTE(review): the next line appears to have lost text during extraction (content between
// '<' and '>' is missing and several statements are merged); restore it from upstream before building.
flags.resize(atoms1.size(), (int) (1LL< a2) == swapped && a1IsDuplicate && binary_search(duplicateAtoms.begin(), duplicateAtoms.end(), a2))
isExcluded = true; // Both atoms are in both sets, so skip duplicate interactions.
if (isExcluded) {
// NOTE(review): the next line also appears truncated/merged during extraction; restore from upstream.
flags[i] &= -1-(1< tileSetStart;
tileSetStart.push_back(0);
// Pack consecutive tiles (in tileOrder) into tile sets whose first-axis sizes sum to at most 32.
int tileSetSize = 0;
for (int i = 0; i < tileOrder.size(); i++) {
int tile = tileOrder[i].second;
int size = atomLists[tiles[tile].first].size();
if (tileSetSize+size > 32) {
tileSetStart.push_back(i);
tileSetSize = 0;
}
tileSetSize += size;
}
tileSetStart.push_back(tileOrder.size());
// Build the data structures.
int numTileSets = tileSetStart.size()-1;
vector groupData;
for (int tileSet = 0; tileSet < numTileSets; tileSet++) {
int indexInTileSet = 0;
int minSize = 0;
if (cc.getSIMDWidth() < 32) {
// We need to include a barrier inside the inner loop, so ensure that all
// threads will loop the same number of times.
for (int i = tileSetStart[tileSet]; i < tileSetStart[tileSet+1]; i++)
minSize = max(minSize, (int) atomLists[tiles[tileOrder[i].second].first].size());
}
for (int i = tileSetStart[tileSet]; i < tileSetStart[tileSet+1]; i++) {
int tile = tileOrder[i].second;
vector& atoms1 = atomLists[tiles[tile].first];
vector& atoms2 = atomLists[tiles[tile].second];
// Packs the tile's start index within the set (low 16 bits) and its end index (high 16 bits).
int range = indexInTileSet + ((indexInTileSet+max(minSize, (int) atoms1.size()))<<16);
// NOTE(review): the next two lines appear truncated/merged during extraction; restore from upstream.
int allFlags = (1< 0 ? exclusionFlags[tile][j] : allFlags);
groupData.push_back(mm_int4(a1, a2, range, flags<(cc, groupData.size(), "interactionGroupData");
interactionGroupData.upload(groupData);
numGroupTiles.initialize(cc, 1, "numGroupTiles");
// Allocate space for a neighbor list, if necessary.
// Only done with a cutoff and when there are more entries than thread blocks.
if (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff && groupData.size() > cc.getNumThreadBlocks()) {
filteredGroupData.initialize(cc, groupData.size(), "filteredGroupData");
interactionGroupData.copyTo(filteredGroupData);
// Presumably 32 groupData entries per tile set -- TODO confirm against the kernel source.
int numTiles = groupData.size()/32;
numGroupTiles.upload(&numTiles);
}
// Create the kernel.
hasParamDerivs = (force.getNumEnergyParameterDerivatives() > 0);
map replacements;
replacements["COMPUTE_INTERACTION"] = interactionSource;
const string suffixes[] = {"x", "y", "z", "w"};
stringstream localData;
int localDataSize = 0;
for (int i = 0; i < paramBuffers.size(); i++) {
// NOTE(review): the next line appears truncated/merged during extraction; restore from upstream.
localData<& allParamDerivNames = cc.getEnergyParamDerivNames();
int numDerivs = allParamDerivNames.size();
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
string paramName = force.getEnergyParameterDerivativeName(i);
string derivVariable = cc.getNonbondedUtilities().addEnergyParameterDerivative(paramName);
// NOTE(review): the next line appears truncated/merged during extraction; restore from upstream.
initDerivs<<"mixed "< defines;
if (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff)
defines["USE_CUTOFF"] = "1";
if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic)
defines["USE_PERIODIC"] = "1";
int localMemorySize = max(32, cc.getNonbondedUtilities().getForceThreadBlockSize());
defines["LOCAL_MEMORY_SIZE"] = cc.intToString(localMemorySize);
defines["WARPS_IN_BLOCK"] = cc.intToString(localMemorySize/32);
double cutoff = force.getCutoffDistance();
defines["CUTOFF_SQUARED"] = cc.doubleToString(cutoff*cutoff);
double paddedCutoff = cc.getNonbondedUtilities().padCutoff(cutoff);
defines["PADDED_CUTOFF_SQUARED"] = cc.doubleToString(paddedCutoff*paddedCutoff);
defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
defines["TILE_SIZE"] = "32";
defines["NUM_TILES"] = cc.intToString(numTileSets);
// Divide the tile sets evenly among the contexts (devices); this one handles [startIndex, endIndex).
int numContexts = cc.getNumContexts();
int startIndex = cc.getContextIndex()*numTileSets/numContexts;
int endIndex = (cc.getContextIndex()+1)*numTileSets/numContexts;
defines["FIRST_TILE"] = cc.intToString(startIndex);
defines["LAST_TILE"] = cc.intToString(endIndex);
if ((localDataSize/4)%2 == 0 && !cc.getUseDoublePrecision())
defines["PARAMETER_SIZE_IS_EVEN"] = "1";
ComputeProgram program = cc.compileProgram(cc.replaceStrings(CommonKernelSources::customNonbondedGroups, replacements), defines);
interactionGroupKernel = program->createKernel("computeInteractionGroups");
prepareNeighborListKernel = program->createKernel("prepareToBuildNeighborList");
buildNeighborListKernel = program->createKernel("buildNeighborList");
numGroupThreadBlocks = cc.getNonbondedUtilities().getNumForceThreadBlocks();
}
/**
 * Launch this force's device work for one evaluation.
 *
 * The steps are:
 * - Refresh the long range correction when global parameters change, reusing cached
 *   coefficients when available.
 * - Evaluate per-particle computed values.
 * - When interaction groups are used, run the interaction group kernels (building the
 *   group neighbor list first if one is in use).
 *
 * Forces and energies are accumulated into device buffers, and the long range
 * correction energy is added by LongRangePostComputation, so this always returns 0.
 */
double CommonCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    useNeighborList = (filteredGroupData.isInitialized() && cc.getNonbondedUtilities().getUseCutoff());
    if (useNeighborList && cc.getContextIndex() > 0) {
        // When using a neighbor list, run the whole calculation on a single device.
        return 0.0;
    }
    ContextSelector selector(cc);

    // Recompute the long range correction if it has never been computed for the current
    // force, or if any global parameter value has changed.
    bool recomputeLongRangeCorrection = !hasInitializedLongRangeCorrection;
    if (needGlobalParams && forceCopy != NULL) {
        for (int i = 0; i < (int) globalParamNames.size(); i++) {
            float value = (float) context.getParameter(globalParamNames[i]);
            if (value != globalParamValues[i])
                recomputeLongRangeCorrection = true;
            globalParamValues[i] = value;
        }
    }
    if (recomputeLongRangeCorrection) {
        auto cached = longRangeCoefficientCache.find(globalParamValues);
        if (cached != longRangeCoefficientCache.end()) {
            // Reuse coefficients computed earlier for these global parameter values.  The
            // caches are cleared whenever the force changes, so the entry is still valid.
            longRangeCoefficient = cached->second;
            longRangeCoefficientDerivs = longRangeCoefficientDerivsCache[globalParamValues];
            recomputeLongRangeCorrection = false;
            hasInitializedLongRangeCorrection = true;
        }
    }
    if (recomputeLongRangeCorrection) {
        // The correction only affects the energy and energy-parameter derivatives, so defer it
        // (leaving it marked stale) when neither is requested.
        if (includeEnergy || forceCopy->getNumEnergyParameterDerivatives() > 0) {
            cc.getWorkThread().addTask(new LongRangeTask(cc, context.getOwner(), longRangeCorrectionData, globalParamValues, longRangeCoefficient,
                    longRangeCoefficientDerivs, forceCopy, longRangeCoefficientCache, longRangeCoefficientDerivsCache));
            hasInitializedLongRangeCorrection = true;
        }
        else
            hasInitializedLongRangeCorrection = false;
    }

    if (computedValues != NULL) {
        // The computed-values kernel only declares (and was only given a placeholder for) a
        // globals argument when the force has global parameters.  Binding it unconditionally
        // would overwrite the next argument (a per-particle parameter buffer).
        if (needGlobalParams)
            computedValuesKernel->setArg(computedValues->getParameterInfos().size(), cc.getGlobalParamValues());
        computedValuesKernel->execute(cc.getNumAtoms());
    }

    if (interactionGroupData.isInitialized()) {
        if (!hasInitializedKernel) {
            // Bind kernel arguments on first use.
            hasInitializedKernel = true;
            interactionGroupKernel->addArg(cc.getLongForceBuffer());
            interactionGroupKernel->addArg(cc.getEnergyBuffer());
            interactionGroupKernel->addArg(cc.getPosq());
            interactionGroupKernel->addArg((useNeighborList ? filteredGroupData : interactionGroupData));
            interactionGroupKernel->addArg(numGroupTiles);
            interactionGroupKernel->addArg((int) useNeighborList);
            for (int i = 0; i < 5; i++)
                interactionGroupKernel->addArg(); // Periodic box information will be set just before it is executed.
            interactionGroupKernel->addArg((int) cc.getEnergyParamDerivNames().size());
            for (auto& buffer : paramBuffers)
                interactionGroupKernel->addArg(buffer.getArray());
            for (auto& buffer : computedValueBuffers)
                interactionGroupKernel->addArg(buffer.getArray());
            for (auto& function : tabulatedFunctionArrays)
                interactionGroupKernel->addArg(function);
            if (needGlobalParams)
                interactionGroupKernel->addArg(cc.getGlobalParamValues());
            if (hasParamDerivs)
                interactionGroupKernel->addArg(cc.getEnergyParamDerivBuffer());
            if (useNeighborList) {
                // Initialize kernels for building the interaction group neighbor list.
                prepareNeighborListKernel->addArg(cc.getNonbondedUtilities().getRebuildNeighborList());
                prepareNeighborListKernel->addArg(numGroupTiles);
                buildNeighborListKernel->addArg(cc.getNonbondedUtilities().getRebuildNeighborList());
                buildNeighborListKernel->addArg(numGroupTiles);
                buildNeighborListKernel->addArg(cc.getPosq());
                buildNeighborListKernel->addArg(interactionGroupData);
                buildNeighborListKernel->addArg(filteredGroupData);
                for (int i = 0; i < 5; i++)
                    buildNeighborListKernel->addArg(); // Periodic box information will be set just before it is executed.
            }
        }
        int forceThreadBlockSize = max(32, cc.getNonbondedUtilities().getForceThreadBlockSize());
        if (useNeighborList) {
            // Rebuild the neighbor list, if necessary.
            setPeriodicBoxArgs(cc, buildNeighborListKernel, 5);
            prepareNeighborListKernel->execute(1, 1);
            buildNeighborListKernel->execute(numGroupThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
        }
        setPeriodicBoxArgs(cc, interactionGroupKernel, 6);
        interactionGroupKernel->execute(numGroupThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
    }
    return 0;
}
/**
 * Push updated parameters from `force` into the existing device state.
 *
 * The steps are:
 * - Re-upload per-particle parameters for particles firstParticle..lastParticle.  When
 *   firstParticle > lastParticle, no particle parameters are uploaded.
 * - Re-upload any tabulated function whose update count has changed.
 * - If a long range correction is in use, refresh its data and invalidate the cached
 *   coefficients.
 *
 * Throws OpenMMException if the particle count differs from the context's atom count.
 */
void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force, int firstParticle, int lastParticle) {
    ContextSelector selector(cc);
    if (force.getNumParticles() != cc.getNumAtoms())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");

    // Record the per-particle parameters for the requested range only.
    const bool particlesChanged = (firstParticle <= lastParticle);
    if (particlesChanged) {
        const int count = lastParticle-firstParticle+1;
        vector<vector<float> > values(count, vector<float>(force.getNumPerParticleParameters(), 0));
        vector<double> particleParams;
        for (int offset = 0; offset < count; offset++) {
            force.getParticleParameters(firstParticle+offset, particleParams);
            vector<float>& row = values[offset];
            row.resize(particleParams.size());
            for (size_t k = 0; k < particleParams.size(); k++)
                row[k] = (float) particleParams[k];
        }
        params->setParameterValuesSubset(firstParticle, values);
    }

    // Re-upload only the tabulated functions whose update count has changed.
    for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
        const auto& function = force.getTabulatedFunction(i);
        const string name = force.getTabulatedFunctionName(i);
        const int updateCount = function.getUpdateCount();
        if (updateCount == tabulatedFunctionUpdateCount[name])
            continue;
        tabulatedFunctionUpdateCount[name] = updateCount;
        int width;
        vector<float> coefficients = cc.getExpressionUtilities().computeFunctionCoefficients(function, width);
        tabulatedFunctionArrays[i].upload(coefficients);
    }

    // A changed force invalidates the long range correction and every cached coefficient.
    if (forceCopy != NULL) {
        longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, cc.getThreadPool().getNumThreads());
        hasInitializedLongRangeCorrection = false;
        *forceCopy = force;
        longRangeCoefficientCache.clear();
        longRangeCoefficientDerivsCache.clear();
    }

    // Mark that the current reordering may be invalid.
    cc.invalidateMolecules(info, particlesChanged, false);
}