/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "openmm/common/CommonCalcCustomManyParticleForceKernel.h"
#include "openmm/common/CommonKernelUtilities.h"
#include "openmm/common/ContextSelector.h"
#include "openmm/common/ExpressionUtilities.h"
#include "openmm/Context.h"
#include "openmm/internal/ContextImpl.h"
#include "openmm/internal/CustomManyParticleForceImpl.h"
#include "CommonKernelSources.h"
#include "SimTKOpenMMRealType.h"
#include "lepton/CustomFunction.h"
#include "lepton/ExpressionTreeNode.h"
#include "lepton/Operation.h"
#include "lepton/Parser.h"
#include "lepton/ParsedExpression.h"
using namespace OpenMM;
using namespace std;
using namespace Lepton;
// Describes a CustomManyParticleForce to the compute context: which particles are
// interchangeable and which particles are tied together into groups. An instance is
// created in initialize() and handed to cc.addForce(); the same instance is later
// passed to cc.invalidateMolecules() when parameters change.
class CommonCalcCustomManyParticleForceKernel::ForceInfo : public ComputeForceInfo {
public:
// Keeps a reference to the force; the force must therefore outlive this object.
ForceInfo(const CustomManyParticleForce& force) : force(force) {
}
// Returns true when the two particles have the same type and the same per-particle
// parameter values, false otherwise.
bool areParticlesIdentical(int particle1, int particle2) {
// Scratch vectors are thread_local static, so each thread reuses its own pair across calls.
thread_local static vector params1, params2;
int type1, type2;
force.getParticleParameters(particle1, params1, type1);
force.getParticleParameters(particle2, params2, type2);
if (type1 != type2)
return false;
// Only params1.size() entries are compared; this assumes both particles report the
// same number of parameters (presumably always true for one force - confirm).
for (int i = 0; i < (int) params1.size(); i++)
if (params1[i] != params2[i])
return false;
return true;
}
// Every exclusion of the force is reported as one particle group.
int getNumParticleGroups() {
return force.getNumExclusions();
}
// Fills 'particles' with the two particle indices of exclusion number 'index'.
void getParticlesInGroup(int index, vector& particles) {
int particle1, particle2;
force.getExclusionParticles(index, particle1, particle2);
particles.resize(2);
particles[0] = particle1;
particles[1] = particle2;
}
// All groups are treated as equivalent: an exclusion is queried here only as a pair of
// particle indices, so there is nothing further to compare.
bool areGroupsIdentical(int group1, int group2) {
return true;
}
private:
const CustomManyParticleForce& force;
};
// Releases the per-particle parameter set owned by this kernel.
CommonCalcCustomManyParticleForceKernel::~CommonCalcCustomManyParticleForceKernel() {
    // The selector stays alive for the whole destructor body (presumably so that cc is
    // the active context while device resources are released - confirm).
    ContextSelector selector(cc);
    // Deleting a null pointer is a no-op, so no explicit null test is required.
    delete params;
}
// Builds all state needed to evaluate a CustomManyParticleForce on the compute context:
// uploads per-particle parameters and tabulated functions, builds the type-filter and
// exclusion arrays, allocates the neighbor list arrays when a cutoff is in use, generates
// the source fragments substituted into the kernel template, and compiles the kernels.
//
// system: not referenced anywhere in the visible body.
// force:  the force whose definition is translated into device data and kernel source.
void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, const CustomManyParticleForce& force) {
ContextSelector selector(cc);
int numParticles = force.getNumParticles();
int particlesPerSet = force.getNumParticlesPerSet();
bool centralParticleMode = (force.getPermutationMode() == CustomManyParticleForce::UniqueCentralParticle);
nonbondedMethod = CalcCustomManyParticleForceKernel::NonbondedMethod(force.getNonbondedMethod());
// Workgroup sizes used when launching the kernels in execute(). The neighbor search uses
// 128 threads when the SIMD width is at least 32 and 32 threads otherwise.
forceWorkgroupSize = 128;
findNeighborsWorkgroupSize = (cc.getSIMDWidth() >= 32 ? 128 : 32);
// Record parameter values. Each value is cast to float before being uploaded.
params = new ComputeParameterSet(cc, force.getNumPerParticleParameters(), numParticles, "customManyParticleParameters");
vector > paramVector(numParticles);
for (int i = 0; i < numParticles; i++) {
vector parameters;
int type;
force.getParticleParameters(i, parameters, type);
paramVector[i].resize(parameters.size());
for (int j = 0; j < (int) parameters.size(); j++)
paramVector[i][j] = (float) parameters[j];
}
params->setParameterValues(paramVector);
// Register the particle/group description of this force with the context.
info = new ForceInfo(force);
cc.addForce(info);
// Record the tabulated functions. For each one: remember its update count (compared
// again in copyParametersToContext), upload its coefficients to a device array, and
// append a matching argument declaration ("tableN") to the kernel signature.
map functions;
vector > functionDefinitions;
vector functionList;
stringstream tableArgs;
tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].initialize(cc, f.size(), "TabulatedFunction");
tabulatedFunctionArrays[i].upload(f);
// The argument type is "float" followed by the width when width > 1 (e.g. float4).
tableArgs << ", GLOBAL const float";
if (width > 1)
tableArgs << width;
tableArgs << "* RESTRICT " << arrayName;
}
// Record information about parameters. Each entry maps a name usable in the energy
// expression to the kernel-source expression that yields its value:
//   xN, yN, zN            -> posN.x, posN.y, posN.z
//   <per-particle name>N  -> ((real) params...) with the suffix chosen by the parameter set
//   <global name>         -> globals[index], index assigned by cc.registerGlobalParam()
vector > variables;
for (int i = 0; i < particlesPerSet; i++) {
string index = cc.intToString(i+1);
variables.push_back(makeVariable("x"+index, "pos"+index+".x"));
variables.push_back(makeVariable("y"+index, "pos"+index+".y"));
variables.push_back(makeVariable("z"+index, "pos"+index+".z"));
}
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
const string& name = force.getPerParticleParameterName(i);
for (int j = 0; j < particlesPerSet; j++) {
string index = cc.intToString(j+1);
variables.push_back(makeVariable(name+index, "((real) params"+params->getParameterSuffix(i, index)+")"));
}
}
// Remembered so execute() knows whether to pass the global parameter array to the kernel.
needGlobalParams = (force.getNumGlobalParameters() > 0);
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i);
int index = cc.registerGlobalParam(name);
string value = "globals["+cc.intToString(index)+"]";
variables.push_back(makeVariable(name, value));
}
// Build data structures for type filters. The device arrays are only created when
// buildFilterArrays() reports more than one particle order; execute() checks
// particleTypes.isInitialized() to decide whether to pass them.
vector particleTypesVec;
vector orderIndexVec;
vector > particleOrderVec;
int numTypes;
CustomManyParticleForceImpl::buildFilterArrays(force, numTypes, particleTypesVec, orderIndexVec, particleOrderVec);
bool hasTypeFilters = (particleOrderVec.size() > 1);
if (hasTypeFilters) {
particleTypes.initialize(cc, particleTypesVec.size(), "customManyParticleTypes");
orderIndex.initialize(cc, orderIndexVec.size(), "customManyParticleOrderIndex");
particleOrder.initialize(cc, particleOrderVec.size()*particlesPerSet, "customManyParticleOrder");
particleTypes.upload(particleTypesVec);
orderIndex.upload(orderIndexVec);
// Flatten the list of orders row by row: entry (i, j) goes to i*particlesPerSet+j.
vector flattenedOrder(particleOrder.getSize());
for (int i = 0; i < (int) particleOrderVec.size(); i++)
for (int j = 0; j < particlesPerSet; j++)
flattenedOrder[i*particlesPerSet+j] = particleOrderVec[i][j];
particleOrder.upload(flattenedOrder);
}
// Build data structures for exclusions. Each exclusion is stored in both directions.
// The result is one concatenated list of sorted partners plus a start-index array:
// the partners of particle i occupy [exclusionStartIndex[i], exclusionStartIndex[i+1]).
if (force.getNumExclusions() > 0) {
vector > particleExclusions(numParticles);
for (int i = 0; i < force.getNumExclusions(); i++) {
int p1, p2;
force.getExclusionParticles(i, p1, p2);
particleExclusions[p1].push_back(p2);
particleExclusions[p2].push_back(p1);
}
vector exclusionsVec;
vector exclusionStartIndexVec(numParticles+1);
exclusionStartIndexVec[0] = 0;
for (int i = 0; i < numParticles; i++) {
sort(particleExclusions[i].begin(), particleExclusions[i].end());
exclusionsVec.insert(exclusionsVec.end(), particleExclusions[i].begin(), particleExclusions[i].end());
exclusionStartIndexVec[i+1] = exclusionsVec.size();
}
exclusions.initialize(cc, exclusionsVec.size(), "customManyParticleExclusions");
exclusionStartIndex.initialize(cc, exclusionStartIndexVec.size(), "customManyParticleExclusionStart");
exclusions.upload(exclusionsVec);
exclusionStartIndex.upload(exclusionStartIndexVec);
}
// Build data structures for the neighbor list. Atoms are handled in blocks of 32 (the
// same value is passed to the kernels as TILE_SIZE below). The arrays are only
// allocated when a cutoff is in use.
int numAtomBlocks = cc.getPaddedNumAtoms()/32;
if (nonbondedMethod != NoCutoff) {
// Per-block entries are four values of the context's precision (float or double).
int elementSize = (cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float));
blockCenter.initialize(cc, numAtomBlocks, 4*elementSize, "blockCenter");
blockBoundingBox.initialize(cc, numAtomBlocks, 4*elementSize, "blockBoundingBox");
numNeighborPairs.initialize(cc, 1, "customManyParticleNumNeighborPairs");
neighborStartIndex.initialize(cc, numParticles+1, "customManyParticleNeighborStartIndex");
numNeighborsForAtom.initialize(cc, numParticles, "customManyParticleNumNeighborsForAtom");
// Select a size for the array that holds the neighbor list. We have to make a fairly
// arbitrary guess, but if this turns out to be too small we'll increase it later
// (execute() resizes the arrays and reruns when the pair count exceeds this value).
maxNeighborPairs = 150*numParticles;
neighborPairs.initialize(cc, maxNeighborPairs, "customManyParticleNeighborPairs");
neighbors.initialize(cc, maxNeighborPairs, "customManyParticleNeighbors");
}
// Generate the kernel. The string streams filled in below become the replacement
// snippets substituted into the kernel template further down.
// NOTE(review): several statements in this section look truncated in this copy of the
// file (the ones beginning "compute<", "loadData<", "isValidCombination<<" and
// "extraArgs<<\", GLOBAL const \"<"): text between a '<' and a later '>' appears to be
// missing, together with the declarations of loadData, isValidCombination,
// numCombinations, atomsForCombination, verifyCutoff, verifyExclusions, permute and
// replacements. Confirm against version control before editing this part.
Lepton::ParsedExpression energyExpression = CustomManyParticleForceImpl::prepareExpression(force, functions);
map forceExpressions;
stringstream compute;
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
ComputeParameterInfo& parameter = params->getParameterInfos()[i];
compute< forceNames;
for (int i = 0; i < particlesPerSet; i++) {
string istr = cc.intToString(i+1);
string forceName = "force"+istr;
forceNames.push_back(forceName);
compute<<"real3 "<getParameterInfos().size(); j++)
loadData<getParameterInfos()[j].getType()<<" params"<<(j+1)<<(i+1)<<" = global_params"<<(j+1)<<"[atom"<<(i+1)<<"];\n";
}
if (centralParticleMode) {
for (int i = 1; i < particlesPerSet; i++) {
if (i > 1)
isValidCombination<<" && p"<<(i+1)<<">p"< 2)
isValidCombination<<" && ";
isValidCombination<<"a"<<(i+1)<<">a"< 1)
numCombinations<<"*";
numCombinations<<"numNeighbors";
// Decode one digit of the combination index per particle. Without a central particle
// the digit is offset by one.
if (centralParticleMode)
atomsForCombination<<"int a"<<(i+1)<<" = tempIndex%numNeighbors;\n";
else
atomsForCombination<<"int a"<<(i+1)<<" = 1+tempIndex%numNeighbors;\n";
if (i < particlesPerSet-1)
atomsForCombination<<"tempIndex /= numNeighbors;\n";
}
// With three or more particles per set, a2 is mirrored whenever a3 is odd.
if (particlesPerSet > 2) {
if (centralParticleMode)
atomsForCombination<<"a2 = (a3%2 == 0 ? a2 : numNeighbors-a2-1);\n";
else
atomsForCombination<<"a2 = (a3%2 == 0 ? a2 : numNeighbors-a2+1);\n";
}
// Translate each combination digit aN into a particle index pN: directly (or relative
// to p1) when there is no cutoff, otherwise through the neighbor list.
for (int i = 1; i < particlesPerSet; i++) {
if (nonbondedMethod == NoCutoff) {
if (centralParticleMode)
atomsForCombination<<"int p"<<(i+1)<<" = a"<<(i+1)<<";\n";
else
atomsForCombination<<"int p"<<(i+1)<<" = p1+a"<<(i+1)<<";\n";
}
else {
if (centralParticleMode)
atomsForCombination<<"int p"<<(i+1)<<" = neighbors[firstNeighbor+a"<<(i+1)<<"];\n";
else
atomsForCombination<<"int p"<<(i+1)<<" = neighbors[firstNeighbor-1+a"<<(i+1)<<"];\n";
}
}
// With a cutoff, load the positions of particles 2..N. Unless a central particle is
// used, also require every pair among particles 2..N to lie within the cutoff.
if (nonbondedMethod != NoCutoff) {
for (int i = 1; i < particlesPerSet; i++)
verifyCutoff<<"real3 pos"<<(i+1)<<" = trimTo3(posq[p"<<(i+1)<<"]);\n";
if (!centralParticleMode) {
for (int i = 1; i < particlesPerSet; i++) {
for (int j = i+1; j < particlesPerSet; j++)
verifyCutoff<<"includeInteraction &= (delta(pos"<<(i+1)<<", pos"<<(j+1)<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ).w < CUTOFF_SQUARED);\n";
}
}
}
// Emit pairwise exclusion checks. With a cutoff the pairs involving particle 1 are
// skipped (presumably because the neighbor list kernel, which also receives the
// exclusion arrays in execute(), has already filtered them - confirm).
if (force.getNumExclusions() > 0) {
int startCheckFrom = (nonbondedMethod == NoCutoff ? 0 : 1);
for (int i = startCheckFrom; i < particlesPerSet; i++)
for (int j = i+1; j < particlesPerSet; j++)
verifyExclusions<<"includeInteraction &= !isInteractionExcluded(p"<<(i+1)<<", p"<<(j+1)<<", exclusions, exclusionStartIndex);\n";
}
// Build a nested expression that combines the types of all particles in the set into a
// single index, treating them as digits in base numTypes (particle 1 least significant).
string computeTypeIndex = "particleTypes[p"+cc.intToString(particlesPerSet)+"]";
for (int i = particlesPerSet-2; i >= 0; i--)
computeTypeIndex = "particleTypes[p"+cc.intToString(i+1)+"]+"+cc.intToString(numTypes)+"*("+computeTypeIndex+")";
// Create replacements for extra arguments.
stringstream extraArgs;
if (force.getNumGlobalParameters() > 0)
extraArgs << ", GLOBAL const real* globals";
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
ComputeParameterInfo& parameter = params->getParameterInfos()[i];
extraArgs<<", GLOBAL const "< replacements;
replacements["COMPUTE_INTERACTION"] = compute.str();
replacements["NUM_CANDIDATE_COMBINATIONS"] = numCombinations.str();
replacements["FIND_ATOMS_FOR_COMBINATION_INDEX"] = atomsForCombination.str();
replacements["IS_VALID_COMBINATION"] = isValidCombination.str();
replacements["VERIFY_CUTOFF"] = verifyCutoff.str();
replacements["VERIFY_EXCLUSIONS"] = verifyExclusions.str();
replacements["PERMUTE_ATOMS"] = permute.str();
replacements["LOAD_PARTICLE_DATA"] = loadData.str();
replacements["COMPUTE_TYPE_INDEX"] = computeTypeIndex;
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
// Preprocessor definitions for the kernel source. The USE_* flags mirror the options
// detected above; the remaining entries are numeric constants.
map defines;
if (nonbondedMethod != NoCutoff)
defines["USE_CUTOFF"] = "1";
if (nonbondedMethod == CutoffPeriodic)
defines["USE_PERIODIC"] = "1";
if (centralParticleMode)
defines["USE_CENTRAL_PARTICLE"] = "1";
if (hasTypeFilters)
defines["USE_FILTERS"] = "1";
if (force.getNumExclusions() > 0)
defines["USE_EXCLUSIONS"] = "1";
defines["NUM_ATOMS"] = cc.intToString(cc.getNumAtoms());
defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
defines["M_PI"] = cc.doubleToString(M_PI);
defines["CUTOFF_SQUARED"] = cc.doubleToString(force.getCutoffDistance()*force.getCutoffDistance());
defines["TILE_SIZE"] = cc.intToString(32);
defines["NUM_BLOCKS"] = cc.intToString(numAtomBlocks);
defines["FIND_NEIGHBORS_WORKGROUP_SIZE"] = cc.intToString(findNeighborsWorkgroupSize);
// Compile the substituted source and look up every kernel used by execute().
ComputeProgram program = cc.compileProgram(cc.replaceStrings(CommonKernelSources::pointFunctions+CommonKernelSources::customManyParticle, replacements), defines);
forceKernel = program->createKernel("computeInteraction");
blockBoundsKernel = program->createKernel("findBlockBounds");
neighborsKernel = program->createKernel("findNeighbors");
startIndicesKernel = program->createKernel("computeNeighborStartIndices");
copyPairsKernel = program->createKernel("copyPairsToNeighborList");
// execute() uses this event to wait for the asynchronous download of the pair count.
event = cc.createEvent();
}
// Launches the kernels that evaluate the many-particle interaction.
//
// On the first call the argument lists of all kernels are built. On every call, when a
// cutoff is in use, the neighbor list is rebuilt first; if it turns out not to fit in the
// allocated arrays, the arrays are enlarged and the whole sequence is run again.
//
// context, includeForces, includeEnergy: not consulted by this implementation.
// Returns 0.0 always; the force kernel is given cc's force and energy buffers as its
// first two arguments.
double CommonCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    ContextSelector selector(cc);
    if (!hasInitializedKernel) {
        hasInitializedKernel = true;

        // Set arguments for the force kernel. Arguments 3-7 are placeholders for the
        // periodic box information, filled in by setPeriodicBoxArgs().
        forceKernel->addArg(cc.getLongForceBuffer());
        forceKernel->addArg(cc.getEnergyBuffer());
        forceKernel->addArg(cc.getPosq());
        for (int i = 0; i < 5; i++)
            forceKernel->addArg();
        setPeriodicBoxArgs(cc, forceKernel, 3);
        if (nonbondedMethod != NoCutoff) {
            forceKernel->addArg(neighbors);          // argument 8, re-bound after a resize below
            forceKernel->addArg(neighborStartIndex);
        }
        if (particleTypes.isInitialized()) {
            forceKernel->addArg(particleTypes);
            forceKernel->addArg(orderIndex);
            forceKernel->addArg(particleOrder);
        }
        if (exclusions.isInitialized()) {
            forceKernel->addArg(exclusions);
            forceKernel->addArg(exclusionStartIndex);
        }
        if (needGlobalParams)
            forceKernel->addArg(cc.getGlobalParamValues());
        for (auto& parameter : params->getParameterInfos())
            forceKernel->addArg(parameter.getArray());
        for (auto& function : tabulatedFunctionArrays)
            forceKernel->addArg(function);

        if (nonbondedMethod != NoCutoff) {
            // Set arguments for the block bounds kernel.
            for (int i = 0; i < 5; i++)
                blockBoundsKernel->addArg(); // Periodic box information will be set just before it is executed.
            blockBoundsKernel->addArg(cc.getPosq());
            blockBoundsKernel->addArg(blockCenter);
            blockBoundsKernel->addArg(blockBoundingBox);
            blockBoundsKernel->addArg(numNeighborPairs);

            // Set arguments for the neighbor list kernel.
            for (int i = 0; i < 5; i++)
                neighborsKernel->addArg(); // Periodic box information will be set just before it is executed.
            neighborsKernel->addArg(cc.getPosq());
            neighborsKernel->addArg(blockCenter);
            neighborsKernel->addArg(blockBoundingBox);
            neighborsKernel->addArg(neighborPairs);      // argument 8
            neighborsKernel->addArg(numNeighborPairs);
            neighborsKernel->addArg(numNeighborsForAtom);
            neighborsKernel->addArg(maxNeighborPairs);   // argument 11
            if (exclusions.isInitialized()) {
                neighborsKernel->addArg(exclusions);
                neighborsKernel->addArg(exclusionStartIndex);
            }

            // Set arguments for the kernel to find neighbor list start indices.
            startIndicesKernel->addArg(numNeighborsForAtom);
            startIndicesKernel->addArg(neighborStartIndex);
            startIndicesKernel->addArg(numNeighborPairs);
            startIndicesKernel->addArg(maxNeighborPairs); // argument 3

            // Set arguments for the kernel to assemble the final neighbor list.
            copyPairsKernel->addArg(neighborPairs);       // argument 0
            copyPairsKernel->addArg(neighbors);           // argument 1
            copyPairsKernel->addArg(numNeighborPairs);
            copyPairsKernel->addArg(maxNeighborPairs);    // argument 3
            copyPairsKernel->addArg(numNeighborsForAtom);
            copyPairsKernel->addArg(neighborStartIndex);
        }
    }
    while (true) {
        int* numPairs = (int*) cc.getPinnedBuffer();
        if (nonbondedMethod != NoCutoff) {
            setPeriodicBoxArgs(cc, forceKernel, 3);
            setPeriodicBoxArgs(cc, blockBoundsKernel, 0);
            setPeriodicBoxArgs(cc, neighborsKernel, 0);
            blockBoundsKernel->execute(cc.getPaddedNumAtoms()/32);
            neighborsKernel->execute(cc.getNumAtoms(), findNeighborsWorkgroupSize);

            // We need to make sure there was enough memory for the neighbor list. Download the
            // information asynchronously so kernels can be running at the same time.
            numNeighborPairs.download(numPairs, false);
            event->enqueue();
            startIndicesKernel->execute(256, 256);
            copyPairsKernel->execute(maxNeighborPairs);
        }
        int maxThreads = min(cc.getNumAtoms()*forceWorkgroupSize, (int) cc.getEnergyBuffer().getSize());
        forceKernel->execute(maxThreads, forceWorkgroupSize);
        if (nonbondedMethod != NoCutoff) {
            // Make sure there was enough memory for the neighbor list.
            event->wait();
            if (*numPairs > maxNeighborPairs) {
                // Resize the arrays (with 10% headroom) and run the calculation again.
                maxNeighborPairs = (int) (1.1*(*numPairs));
                neighborPairs.resize(maxNeighborPairs);
                neighbors.resize(maxNeighborPairs);

                // resize() can replace the underlying device buffer, and some platforms
                // bind the buffer handle when the argument is set. Re-bind every kernel
                // argument that refers to a resized array, so no kernel keeps using the
                // old buffer. Where the pointer is resolved at launch this is harmless.
                forceKernel->setArg(8, neighbors);
                neighborsKernel->setArg(8, neighborPairs);
                copyPairsKernel->setArg(0, neighborPairs);
                copyPairsKernel->setArg(1, neighbors);

                // Propagate the new capacity to every kernel that bounds-checks against it.
                neighborsKernel->setArg(11, maxNeighborPairs);
                startIndicesKernel->setArg(3, maxNeighborPairs);
                copyPairsKernel->setArg(3, maxNeighborPairs);
                continue;
            }
        }
        break;
    }
    return 0.0;
}
// Re-uploads the per-particle parameters of 'force' and any tabulated functions whose
// update count has changed since they were last uploaded.
//
// context: not referenced in the body.
// force:   the force to read the current parameter values from.
// Throws OpenMMException when the force's particle count differs from cc.getNumAtoms().
void CommonCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImpl& context, const CustomManyParticleForce& force) {
ContextSelector selector(cc);
int numParticles = force.getNumParticles();
if (numParticles != cc.getNumAtoms())
throw OpenMMException("updateParametersInContext: The number of particles has changed");
// Record the per-particle parameters. Values are cast to float, as in initialize().
// The particle type returned through 'type' is not used here, so the type-filter
// arrays built in initialize() are left untouched by this method.
vector > paramVector(numParticles);
vector parameters;
int type;
for (int i = 0; i < numParticles; i++) {
force.getParticleParameters(i, parameters, type);
paramVector[i].resize(parameters.size());
for (int j = 0; j < (int) parameters.size(); j++)
paramVector[i][j] = (float) parameters[j];
}
params->setParameterValues(paramVector);
// See if any tabulated functions have changed. A function is re-uploaded only when its
// update count differs from the one recorded at the previous upload.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
int width;
vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
// The existing device array is reused as is; this assumes the number of coefficients
// has not changed since initialize() sized it - TODO confirm.
tabulatedFunctionArrays[i].upload(f);
}
}
// Mark that the current reordering may be invalid.
cc.invalidateMolecules(info);
}