/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "openmm/common/CommonCalcCustomGBForceKernel.h"
#include "openmm/common/CommonKernelUtilities.h"
#include "openmm/common/ContextSelector.h"
#include "openmm/common/ExpressionUtilities.h"
#include "openmm/Context.h"
#include "openmm/internal/ContextImpl.h"
#include "CommonKernelSources.h"
#include "lepton/CustomFunction.h"
#include "lepton/ExpressionTreeNode.h"
#include "lepton/Operation.h"
#include "lepton/Parser.h"
#include "lepton/ParsedExpression.h"
using namespace OpenMM;
using namespace std;
using namespace Lepton;
/**
 * Adapter that exposes a CustomGBForce through the ComputeForceInfo interface.
 * Particles are compared by their per-particle parameter lists, and each
 * exclusion of the force is reported as a two-particle group.
 */
class CommonCalcCustomGBForceKernel::ForceInfo : public ComputeForceInfo {
public:
    /**
     * @param force  the force being described; it is held by reference, so it
     *               must outlive this object.
     */
    ForceInfo(const CustomGBForce& force) : force(force) {
    }
    /**
     * Two particles are considered identical when their per-particle
     * parameter lists are equal element by element.
     */
    bool areParticlesIdentical(int particle1, int particle2) {
        // Per-thread scratch buffers are reused across calls to avoid reallocating.
        thread_local static vector<double> params1, params2;
        force.getParticleParameters(particle1, params1);
        force.getParticleParameters(particle2, params2);
        // vector::operator== compares lengths first, so lists of different
        // sizes compare unequal instead of being indexed out of range.
        return params1 == params2;
    }
    /**
     * Each exclusion defined on the force forms one particle group.
     */
    int getNumParticleGroups() {
        return force.getNumExclusions();
    }
    /**
     * Fills @p particles with the two particles of exclusion @p index.
     */
    void getParticlesInGroup(int index, vector<int>& particles) {
        int particle1, particle2;
        force.getExclusionParticles(index, particle1, particle2);
        particles.resize(2);
        particles[0] = particle1;
        particles[1] = particle2;
    }
    /**
     * Exclusions carry no data beyond their particle indices here, so any two
     * groups are treated as interchangeable.
     */
    bool areGroupsIdentical(int group1, int group2) {
        return true;
    }
private:
    const CustomGBForce& force;
};
/**
 * Releases the parameter sets allocated by initialize(). The selector on cc is
 * held for the whole cleanup; presumably it makes the compute context current
 * while device buffers are freed (see ContextSelector to confirm).
 */
CommonCalcCustomGBForceKernel::~CommonCalcCustomGBForceKernel() {
    ContextSelector selector(cc);
    // Applying delete to a null pointer does nothing, so members that were
    // never allocated need no explicit guard.
    delete params;
    delete computedValues;
    delete energyDerivs;
    delete energyDerivChain;
    for (auto* derivSet : dValuedParam)
        delete derivSet;
}
void CommonCalcCustomGBForceKernel::initialize(const System& system, const CustomGBForce& force) {
ContextSelector selector(cc);
if (cc.getNumContexts() > 1)
throw OpenMMException("CustomGBForce does not support using multiple devices");
NonbondedUtilities& nb = cc.getNonbondedUtilities();
cutoff = force.getCutoffDistance();
needGlobalParams = (force.getNumGlobalParameters() > 0);
bool useExclusionsForValue = false;
numComputedValues = force.getNumComputedValues();
vector computedValueNames(numComputedValues);
vector computedValueExpressions(numComputedValues);
if (numComputedValues > 0) {
CustomGBForce::ComputationType type;
force.getComputedValueParameters(0, computedValueNames[0], computedValueExpressions[0], type);
if (type == CustomGBForce::SingleParticle)
throw OpenMMException("The first computed value for a CustomGBForce must be of type ParticlePair or ParticlePairNoExclusions.");
useExclusionsForValue = (type == CustomGBForce::ParticlePair);
for (int i = 1; i < numComputedValues; i++) {
force.getComputedValueParameters(i, computedValueNames[i], computedValueExpressions[i], type);
if (type != CustomGBForce::SingleParticle)
throw OpenMMException("A CustomGBForce may only have one computed value of type ParticlePair or ParticlePairNoExclusions.");
}
}
int forceIndex;
for (forceIndex = 0; forceIndex < system.getNumForces() && &system.getForce(forceIndex) != &force; ++forceIndex)
;
string prefix = "custom"+cc.intToString(forceIndex)+"_";
// Record parameters and exclusions.
int numParticles = force.getNumParticles();
int paddedNumParticles = cc.getPaddedNumAtoms();
int numParams = force.getNumPerParticleParameters();
params = new ComputeParameterSet(cc, force.getNumPerParticleParameters(), paddedNumParticles, "customGBParameters", true);
computedValues = new ComputeParameterSet(cc, numComputedValues, paddedNumParticles, "customGBComputedValues", true, cc.getUseDoublePrecision());
vector > paramVector(paddedNumParticles, vector(numParams, 0));
vector > exclusionList(numParticles);
for (int i = 0; i < numParticles; i++) {
vector parameters;
force.getParticleParameters(i, parameters);
for (int j = 0; j < (int) parameters.size(); j++)
paramVector[i][j] = (float) parameters[j];
exclusionList[i].push_back(i);
}
for (int i = 0; i < force.getNumExclusions(); i++) {
int particle1, particle2;
force.getExclusionParticles(i, particle1, particle2);
exclusionList[particle1].push_back(particle2);
exclusionList[particle2].push_back(particle1);
}
params->setParameterValues(paramVector);
// Record the tabulated functions.
map functions;
vector > functionDefinitions;
vector functionList;
stringstream tableArgs;
tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
string arrayName = prefix+"table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].initialize(cc, f.size(), "TabulatedFunction");
tabulatedFunctionArrays[i].upload(f);
nb.addArgument(ComputeParameterInfo(tabulatedFunctionArrays[i], arrayName, "float", width));
tableArgs << ", GLOBAL const float";
if (width > 1)
tableArgs << width;
tableArgs << "* RESTRICT " << arrayName;
}
// Record derivatives of expressions needed for the chain rule terms.
vector > valueGradientExpressions(numComputedValues);
vector > valueDerivExpressions(numComputedValues);
vector > valueParamDerivExpressions(numComputedValues);
needParameterGradient = false;
for (int i = 0; i < numComputedValues; i++) {
Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
if (i > 0) {
valueGradientExpressions[i].push_back(ex.differentiate("x").optimize());
valueGradientExpressions[i].push_back(ex.differentiate("y").optimize());
valueGradientExpressions[i].push_back(ex.differentiate("z").optimize());
if (!isZeroExpression(valueGradientExpressions[i][0]) || !isZeroExpression(valueGradientExpressions[i][1]) || !isZeroExpression(valueGradientExpressions[i][2]))
needParameterGradient = true;
for (int j = 0; j < i; j++)
valueDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]).optimize());
}
for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++)
valueParamDerivExpressions[i].push_back(ex.differentiate(force.getEnergyParameterDerivativeName(j)).optimize());
}
vector > energyDerivExpressions(force.getNumEnergyTerms());
vector > energyParamDerivExpressions(force.getNumEnergyTerms());
vector needChainForValue(numComputedValues, false);
for (int i = 0; i < force.getNumEnergyTerms(); i++) {
string expression;
CustomGBForce::ComputationType type;
force.getEnergyTermParameters(i, expression, type);
Lepton::ParsedExpression ex = Lepton::Parser::parse(expression, functions).optimize();
for (int j = 0; j < numComputedValues; j++) {
if (type == CustomGBForce::SingleParticle) {
energyDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]).optimize());
if (!isZeroExpression(energyDerivExpressions[i].back()))
needChainForValue[j] = true;
}
else {
energyDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]+"1").optimize());
if (!isZeroExpression(energyDerivExpressions[i].back()))
needChainForValue[j] = true;
energyDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]+"2").optimize());
if (!isZeroExpression(energyDerivExpressions[i].back()))
needChainForValue[j] = true;
}
}
for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++)
energyParamDerivExpressions[i].push_back(ex.differentiate(force.getEnergyParameterDerivativeName(j)).optimize());
}
bool deviceIsCpu = cc.getIsCPU();
valueBuffers.initialize(cc, cc.getPaddedNumAtoms(), "customGBValueBuffers");
longEnergyDerivs.initialize(cc, numComputedValues*cc.getPaddedNumAtoms(), "customGBLongEnergyDerivatives");
energyDerivs = new ComputeParameterSet(cc, numComputedValues, cc.getPaddedNumAtoms(), "customGBEnergyDerivatives", true);
cc.addAutoclearBuffer(valueBuffers);
energyDerivChain = new ComputeParameterSet(cc, numComputedValues, cc.getPaddedNumAtoms(), "customGBEnergyDerivativeChain", true);
needEnergyParamDerivs = (force.getNumEnergyParameterDerivatives() > 0);
dValue0dParam.resize(force.getNumEnergyParameterDerivatives());
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
dValuedParam.push_back(new ComputeParameterSet(cc, numComputedValues, cc.getPaddedNumAtoms(), "dValuedParam", true, cc.getUseDoublePrecision()));
dValue0dParam[i].initialize(cc, cc.getPaddedNumAtoms(), "dValue0dParam");
cc.addAutoclearBuffer(dValue0dParam[i]);
string name = force.getEnergyParameterDerivativeName(i);
cc.addEnergyParameterDerivative(name);
}
// Create the kernels.
bool useCutoff = (force.getNonbondedMethod() != CustomGBForce::NoCutoff);
bool usePeriodic = (force.getNonbondedMethod() != CustomGBForce::NoCutoff && force.getNonbondedMethod() != CustomGBForce::CutoffNonPeriodic);
int numAtomBlocks = cc.getPaddedNumAtoms()/32;
{
// Create the N2 value kernel.
vector > variables;
map rename;
ExpressionTreeNode rnode(new Operation::Variable("r"));
variables.push_back(make_pair(rnode, "r"));
variables.push_back(make_pair(ExpressionTreeNode(new Operation::Square(), rnode), "r2"));
variables.push_back(make_pair(ExpressionTreeNode(new Operation::Reciprocal(), rnode), "invR"));
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
const string& name = force.getPerParticleParameterName(i);
variables.push_back(makeVariable(name+"1", "((real) params"+params->getParameterSuffix(i, "1)")));
variables.push_back(makeVariable(name+"2", "((real) params"+params->getParameterSuffix(i, "2)")));
rename[name+"1"] = name+"2";
rename[name+"2"] = name+"1";
}
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i);
int index = cc.registerGlobalParam(name);
string value = "globals["+cc.intToString(index)+"]";
variables.push_back(makeVariable(name, value));
}
map n2ValueExpressions;
stringstream n2ValueSource;
Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize();
n2ValueExpressions["tempValue1 = "] = ex;
n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename);
for (int i = 0; i < valueParamDerivExpressions[0].size(); i++) {
string variableBase = "temp_dValue0dParam"+cc.intToString(i+1);
if (!isZeroExpression(valueParamDerivExpressions[0][i])) {
n2ValueExpressions[variableBase+"_1 = "] = valueParamDerivExpressions[0][i];
n2ValueExpressions[variableBase+"_2 = "] = valueParamDerivExpressions[0][i].renameVariables(rename);
}
}
n2ValueSource << cc.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionList, functionDefinitions, "temp");
map replacements;
string n2ValueStr = n2ValueSource.str();
replacements["COMPUTE_VALUE"] = n2ValueStr;
stringstream extraArgs, atomParams, loadLocal1, loadLocal2, load1, load2, tempDerivs1, tempDerivs2, storeDeriv1, storeDeriv2;
if (force.getNumGlobalParameters() > 0)
extraArgs << ", GLOBAL const real* globals";
pairValueUsesParam.resize(params->getParameterInfos().size(), false);
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = params->getParameterInfos()[i];
string paramName = "params"+cc.intToString(i+1);
if (n2ValueStr.find(paramName+"1") != n2ValueStr.npos || n2ValueStr.find(paramName+"2") != n2ValueStr.npos) {
extraArgs << ", GLOBAL const " << buffer.getType() << "* RESTRICT global_" << paramName;
atomParams << "LOCAL " << buffer.getType() << " local_" << paramName << "[LOCAL_BUFFER_SIZE];\n";
loadLocal1 << "local_" << paramName << "[localAtomIndex] = " << paramName << "1;\n";
loadLocal2 << "local_" << paramName << "[localAtomIndex] = global_" << paramName << "[j];\n";
load1 << buffer.getType() << " " << paramName << "1 = global_" << paramName << "[atom1];\n";
load2 << buffer.getType() << " " << paramName << "2 = local_" << paramName << "[atom2];\n";
pairValueUsesParam[i] = true;
}
}
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
string derivName = "dValue0dParam"+cc.intToString(i+1);
extraArgs << ", GLOBAL mm_ulong* RESTRICT global_" << derivName;
atomParams << "LOCAL real local_" << derivName << "[LOCAL_BUFFER_SIZE];\n";
loadLocal2 << "local_" << derivName << "[localAtomIndex] = 0;\n";
load1 << "real " << derivName << " = 0;\n";
if (!isZeroExpression(valueParamDerivExpressions[0][i])) {
load2 << "real temp_" << derivName << "_1 = 0;\n";
load2 << "real temp_" << derivName << "_2 = 0;\n";
tempDerivs1 << derivName << " += temp_" << derivName << "_1;\n";
if (deviceIsCpu)
tempDerivs2 << "local_" << derivName << "[j] += temp_" << derivName << "_2;\n";
else
tempDerivs2 << "local_" << derivName << "[tbx+tj] += temp_" << derivName << "_2;\n";
storeDeriv1 << "ATOMIC_ADD(&global_" << derivName << "[offset1], (mm_ulong) realToFixedPoint(" << derivName << "));\n";
if (deviceIsCpu)
storeDeriv2 << "ATOMIC_ADD(&global_" << derivName << "[offset2], (mm_ulong) realToFixedPoint(local_" << derivName << "[tgx]));\n";
else
storeDeriv2 << "ATOMIC_ADD(&global_" << derivName << "[offset2], (mm_ulong) realToFixedPoint(local_" << derivName << "[LOCAL_ID]));\n";
}
}
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
replacements["ATOM_PARAMETER_DATA"] = atomParams.str();
replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
replacements["LOAD_ATOM2_PARAMETERS"] = load2.str();
replacements["ADD_TEMP_DERIVS1"] = tempDerivs1.str();
replacements["ADD_TEMP_DERIVS2"] = tempDerivs2.str();
replacements["STORE_PARAM_DERIVS1"] = storeDeriv1.str();
replacements["STORE_PARAM_DERIVS2"] = storeDeriv2.str();
if (useCutoff)
pairValueDefines["USE_CUTOFF"] = "1";
if (usePeriodic)
pairValueDefines["USE_PERIODIC"] = "1";
if (useExclusionsForValue)
pairValueDefines["USE_EXCLUSIONS"] = "1";
pairValueDefines["LOCAL_BUFFER_SIZE"] = cc.intToString(deviceIsCpu ? 32 : nb.getForceThreadBlockSize());
pairValueDefines["CUTOFF_SQUARED"] = cc.doubleToString(cutoff*cutoff);
pairValueDefines["NUM_ATOMS"] = cc.intToString(cc.getNumAtoms());
pairValueDefines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
pairValueDefines["NUM_BLOCKS"] = cc.intToString(numAtomBlocks);
pairValueDefines["TILE_SIZE"] = "32";
string file;
if (deviceIsCpu)
file = CommonKernelSources::customGBValueN2_cpu;
else
file = CommonKernelSources::customGBValueN2;
pairValueSrc = cc.replaceStrings(file, replacements);
}
{
// Create the kernel to reduce the N2 value and calculate other values.
stringstream reductionSource, extraArgs, deriv0;
if (force.getNumGlobalParameters() > 0)
extraArgs << ", GLOBAL const real* globals";
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = params->getParameterInfos()[i];
string paramName = "params"+cc.intToString(i+1);
extraArgs << ", GLOBAL const " << buffer.getType() << "* RESTRICT " << paramName;
}
for (int i = 0; i < (int) computedValues->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = computedValues->getParameterInfos()[i];
string valueName = "values"+cc.intToString(i+1);
extraArgs << ", GLOBAL " << buffer.getType() << "* RESTRICT global_" << valueName;
reductionSource << buffer.getType() << " local_" << valueName << ";\n";
}
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
string variableName = "dValuedParam_0_"+cc.intToString(i);
extraArgs << ", GLOBAL const mm_long* RESTRICT dValue0dParam" << i;
deriv0 << "real " << variableName << " = RECIP((real) 0x100000000)*dValue0dParam" << i << "[index];\n";
for (int j = 0; j < dValuedParam[i]->getParameterInfos().size(); j++)
extraArgs << ", GLOBAL real* RESTRICT global_dValuedParam_" << j << "_" << i;
deriv0 << "global_dValuedParam_0_" << i << "[index] = dValuedParam_0_" << i << ";\n";
}
reductionSource << "local_values" << computedValues->getParameterSuffix(0) << " = sum;\n";
map variables;
variables["x"] = "pos.x";
variables["y"] = "pos.y";
variables["z"] = "pos.z";
for (int i = 0; i < force.getNumPerParticleParameters(); i++)
variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]");
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
int index = cc.registerGlobalParam(force.getGlobalParameterName(i));
variables[force.getGlobalParameterName(i)] = "globals["+cc.intToString(index)+"]";
}
for (int i = 1; i < numComputedValues; i++) {
variables[computedValueNames[i-1]] = "local_values"+computedValues->getParameterSuffix(i-1);
map valueExpressions;
valueExpressions["local_values"+computedValues->getParameterSuffix(i)+" = "] = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
reductionSource << cc.getExpressionUtilities().createExpressions(valueExpressions, variables, functionList, functionDefinitions, "value"+cc.intToString(i)+"_temp");
}
for (int i = 0; i < (int) computedValues->getParameterInfos().size(); i++) {
string valueName = "values"+cc.intToString(i+1);
reductionSource << "global_" << valueName << "[index] = local_" << valueName << ";\n";
}
if (needEnergyParamDerivs) {
map derivExpressions;
for (int i = 1; i < numComputedValues; i++) {
for (int j = 0; j < valueParamDerivExpressions[i].size(); j++)
derivExpressions["real dValuedParam_"+cc.intToString(i)+"_"+cc.intToString(j)+" = "] = valueParamDerivExpressions[i][j];
for (int j = 0; j < i; j++)
derivExpressions["real dVdV_"+cc.intToString(i)+"_"+cc.intToString(j)+" = "] = valueDerivExpressions[i][j];
}
reductionSource << cc.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, "derivChain_temp");
for (int i = 1; i < numComputedValues; i++) {
for (int j = 0; j < i; j++)
for (int k = 0; k < valueParamDerivExpressions[i].size(); k++)
reductionSource << "dValuedParam_" << i << "_" << k << " += dVdV_" << i << "_" << j << "*dValuedParam_" << j <<"_" << k << ";\n";
for (int j = 0; j < valueParamDerivExpressions[i].size(); j++)
reductionSource << "global_dValuedParam_" << i << "_" << j << "[index] = dValuedParam_" << i << "_" << j << ";\n";
}
}
map replacements;
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
replacements["REDUCE_PARAM0_DERIV"] = deriv0.str();
replacements["COMPUTE_VALUES"] = reductionSource.str();
map defines;
defines["NUM_ATOMS"] = cc.intToString(cc.getNumAtoms());
ComputeProgram program = cc.compileProgram(cc.replaceStrings(CommonKernelSources::customGBValuePerParticle, replacements), defines);
perParticleValueKernel = program->createKernel("computePerParticleValues");
}
{
// Create the N2 energy kernel.
vector > variables;
ExpressionTreeNode rnode(new Operation::Variable("r"));
variables.push_back(make_pair(rnode, "r"));
variables.push_back(make_pair(ExpressionTreeNode(new Operation::Square(), rnode), "r2"));
variables.push_back(make_pair(ExpressionTreeNode(new Operation::Reciprocal(), rnode), "invR"));
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
const string& name = force.getPerParticleParameterName(i);
variables.push_back(makeVariable(name+"1", "((real) params"+params->getParameterSuffix(i, "1)")));
variables.push_back(makeVariable(name+"2", "((real) params"+params->getParameterSuffix(i, "2)")));
}
for (int i = 0; i < numComputedValues; i++) {
variables.push_back(makeVariable(computedValueNames[i]+"1", "values"+computedValues->getParameterSuffix(i, "1")));
variables.push_back(makeVariable(computedValueNames[i]+"2", "values"+computedValues->getParameterSuffix(i, "2")));
}
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
int index = cc.registerGlobalParam(force.getGlobalParameterName(i));
variables.push_back(makeVariable(force.getGlobalParameterName(i), "globals["+cc.intToString(index)+"]"));
}
stringstream n2EnergySource;
bool anyExclusions = (force.getNumExclusions() > 0);
for (int i = 0; i < force.getNumEnergyTerms(); i++) {
string expression;
CustomGBForce::ComputationType type;
force.getEnergyTermParameters(i, expression, type);
if (type == CustomGBForce::SingleParticle)
continue;
bool exclude = (anyExclusions && type == CustomGBForce::ParticlePair);
map n2EnergyExpressions;
n2EnergyExpressions["tempEnergy += "] = Lepton::Parser::parse(expression, functions).optimize();
n2EnergyExpressions["dEdR += "] = Lepton::Parser::parse(expression, functions).differentiate("r").optimize();
for (int j = 0; j < numComputedValues; j++) {
if (needChainForValue[j]) {
string index = cc.intToString(j+1);
n2EnergyExpressions["/*"+cc.intToString(i+1)+"*/ deriv"+index+"_1 += "] = energyDerivExpressions[i][2*j];
n2EnergyExpressions["/*"+cc.intToString(i+1)+"*/ deriv"+index+"_2 += "] = energyDerivExpressions[i][2*j+1];
}
}
for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++)
n2EnergyExpressions["energyParamDeriv"+cc.intToString(j)+" += interactionScale*"] = energyParamDerivExpressions[i][j];
if (exclude)
n2EnergySource << "if (!isExcluded) {\n";
n2EnergySource << cc.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionList, functionDefinitions, "temp");
if (exclude)
n2EnergySource << "}\n";
}
map replacements;
string n2EnergyStr = n2EnergySource.str();
replacements["COMPUTE_INTERACTION"] = n2EnergyStr;
stringstream extraArgs, atomParams, loadLocal1, loadLocal2, clearLocal, load1, load2, declare1, recordDeriv, storeDerivs1, storeDerivs2, initParamDerivs, saveParamDerivs;
if (force.getNumGlobalParameters() > 0)
extraArgs << ", GLOBAL const real* globals";
pairEnergyUsesParam.resize(params->getParameterInfos().size(), false);
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = params->getParameterInfos()[i];
string paramName = "params"+cc.intToString(i+1);
if (n2EnergyStr.find(paramName+"1") != n2EnergyStr.npos || n2EnergyStr.find(paramName+"2") != n2EnergyStr.npos) {
extraArgs << ", GLOBAL const " << buffer.getType() << "* RESTRICT global_" << paramName;
atomParams << "LOCAL " << buffer.getType() << " local_" << paramName << "[LOCAL_BUFFER_SIZE];\n";
loadLocal1 << "local_" << paramName << "[localAtomIndex] = " << paramName << "1;\n";
loadLocal2 << "local_" << paramName << "[localAtomIndex] = global_" << paramName << "[j];\n";
load1 << buffer.getType() << " " << paramName << "1 = global_" << paramName << "[atom1];\n";
load2 << buffer.getType() << " " << paramName << "2 = local_" << paramName << "[atom2];\n";
pairEnergyUsesParam[i] = true;
}
}
pairEnergyUsesValue.resize(computedValues->getParameterInfos().size(), false);
for (int i = 0; i < (int) computedValues->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = computedValues->getParameterInfos()[i];
string valueName = "values"+cc.intToString(i+1);
if (n2EnergyStr.find(valueName+"1") != n2EnergyStr.npos || n2EnergyStr.find(valueName+"2") != n2EnergyStr.npos) {
extraArgs << ", GLOBAL const " << buffer.getType() << "* RESTRICT global_" << valueName;
atomParams << "LOCAL " << buffer.getType() << " local_" << valueName << "[LOCAL_BUFFER_SIZE];\n";
loadLocal1 << "local_" << valueName << "[localAtomIndex] = " << valueName << "1;\n";
loadLocal2 << "local_" << valueName << "[localAtomIndex] = global_" << valueName << "[j];\n";
load1 << buffer.getType() << " " << valueName << "1 = global_" << valueName << "[atom1];\n";
load2 << buffer.getType() << " " << valueName << "2 = local_" << valueName << "[atom2];\n";
pairEnergyUsesValue[i] = true;
}
}
extraArgs << ", GLOBAL mm_ulong* RESTRICT derivBuffers";
for (int i = 0; i < numComputedValues; i++) {
string index = cc.intToString(i+1);
atomParams << "LOCAL real local_deriv" << index << "[LOCAL_BUFFER_SIZE];\n";
clearLocal << "local_deriv" << index << "[localAtomIndex] = 0.0f;\n";
declare1 << "real deriv" << index << "_1 = 0;\n";
load2 << "real deriv" << index << "_2 = 0;\n";
recordDeriv << "local_deriv" << index << "[atom2] += deriv" << index << "_2;\n";
storeDerivs1 << "STORE_DERIVATIVE_1(" << index << ")\n";
storeDerivs2 << "STORE_DERIVATIVE_2(" << index << ")\n";
}
if (needEnergyParamDerivs) {
extraArgs << ", GLOBAL mixed* RESTRICT energyParamDerivs";
const vector& allParamDerivNames = cc.getEnergyParamDerivNames();
int numDerivs = allParamDerivNames.size();
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
initParamDerivs << "mixed energyParamDeriv" << i << " = 0;\n";
for (int index = 0; index < numDerivs; index++)
if (allParamDerivNames[index] == force.getEnergyParameterDerivativeName(i))
saveParamDerivs << "energyParamDerivs[GLOBAL_ID*" << numDerivs << "+" << index << "] += energyParamDeriv" << i << ";\n";
}
}
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
replacements["ATOM_PARAMETER_DATA"] = atomParams.str();
replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
replacements["CLEAR_LOCAL_DERIVATIVES"] = clearLocal.str();
replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
replacements["LOAD_ATOM2_PARAMETERS"] = load2.str();
replacements["DECLARE_ATOM1_DERIVATIVES"] = declare1.str();
replacements["RECORD_DERIVATIVE_2"] = recordDeriv.str();
replacements["STORE_DERIVATIVES_1"] = storeDerivs1.str();
replacements["STORE_DERIVATIVES_2"] = storeDerivs2.str();
replacements["INIT_PARAM_DERIVS"] = initParamDerivs.str();
replacements["SAVE_PARAM_DERIVS"] = saveParamDerivs.str();
if (useCutoff)
pairEnergyDefines["USE_CUTOFF"] = "1";
if (usePeriodic)
pairEnergyDefines["USE_PERIODIC"] = "1";
if (anyExclusions)
pairEnergyDefines["USE_EXCLUSIONS"] = "1";
pairEnergyDefines["LOCAL_BUFFER_SIZE"] = cc.intToString(deviceIsCpu ? 32 : nb.getForceThreadBlockSize());
pairEnergyDefines["CUTOFF_SQUARED"] = cc.doubleToString(cutoff*cutoff);
pairEnergyDefines["NUM_ATOMS"] = cc.intToString(cc.getNumAtoms());
pairEnergyDefines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
pairEnergyDefines["NUM_BLOCKS"] = cc.intToString(numAtomBlocks);
pairEnergyDefines["TILE_SIZE"] = "32";
string file;
if (deviceIsCpu)
file = CommonKernelSources::customGBEnergyN2_cpu;
else
file = CommonKernelSources::customGBEnergyN2;
pairEnergySrc = cc.replaceStrings(file, replacements);
}
{
// Create the kernel to reduce the derivatives and calculate per-particle energy terms.
stringstream compute, extraArgs, reduce, initParamDerivs, saveParamDerivs;
if (force.getNumGlobalParameters() > 0)
extraArgs << ", GLOBAL const real* globals";
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = params->getParameterInfos()[i];
string paramName = "params"+cc.intToString(i+1);
extraArgs << ", GLOBAL const " << buffer.getType() << "* RESTRICT " << paramName;
}
for (int i = 0; i < (int) computedValues->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = computedValues->getParameterInfos()[i];
string valueName = "values"+cc.intToString(i+1);
extraArgs << ", GLOBAL const " << buffer.getType() << "* RESTRICT " << valueName;
}
for (int i = 0; i < (int) energyDerivs->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = energyDerivs->getParameterInfos()[i];
string index = cc.intToString(i+1);
extraArgs << ", GLOBAL " << buffer.getType() << "* RESTRICT derivBuffers" << index;
compute << buffer.getType() << " deriv" << index << " = derivBuffers" << index << "[index];\n";
}
for (int i = 0; i < (int) energyDerivChain->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = energyDerivChain->getParameterInfos()[i];
string index = cc.intToString(i+1);
extraArgs << ", GLOBAL " << buffer.getType() << "* RESTRICT derivChain" << index;
}
extraArgs << ", GLOBAL const mm_long* RESTRICT derivBuffersIn";
for (int i = 0; i < energyDerivs->getNumParameters(); ++i)
reduce << "derivBuffers" << energyDerivs->getParameterSuffix(i, "[index]") <<
" = RECIP((real) 0x100000000)*derivBuffersIn[index+PADDED_NUM_ATOMS*" << cc.intToString(i) << "];\n";
if (needEnergyParamDerivs) {
extraArgs << ", GLOBAL mixed* RESTRICT energyParamDerivs";
const vector& allParamDerivNames = cc.getEnergyParamDerivNames();
int numDerivs = allParamDerivNames.size();
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
initParamDerivs << "mixed energyParamDeriv" << i << " = 0;\n";
for (int index = 0; index < numDerivs; index++)
if (allParamDerivNames[index] == force.getEnergyParameterDerivativeName(i))
saveParamDerivs << "energyParamDerivs[GLOBAL_ID*" << numDerivs << "+" << index << "] += energyParamDeriv" << i << ";\n";
}
}
// Compute the various expressions.
map variables;
variables["x"] = "pos.x";
variables["y"] = "pos.y";
variables["z"] = "pos.z";
for (int i = 0; i < force.getNumPerParticleParameters(); i++)
variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]");
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
int index = cc.registerGlobalParam(force.getGlobalParameterName(i));
variables[force.getGlobalParameterName(i)] = "globals["+cc.intToString(index)+"]";
}
for (int i = 0; i < numComputedValues; i++)
variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]");
map expressions;
for (int i = 0; i < force.getNumEnergyTerms(); i++) {
string expression;
CustomGBForce::ComputationType type;
force.getEnergyTermParameters(i, expression, type);
if (type != CustomGBForce::SingleParticle)
continue;
Lepton::ParsedExpression parsed = Lepton::Parser::parse(expression, functions).optimize();
expressions["/*"+cc.intToString(i+1)+"*/ energy += "] = parsed;
for (int j = 0; j < numComputedValues; j++)
expressions["/*"+cc.intToString(i+1)+"*/ deriv"+energyDerivs->getParameterSuffix(j)+" += "] = energyDerivExpressions[i][j];
Lepton::ParsedExpression gradx = parsed.differentiate("x").optimize();
Lepton::ParsedExpression grady = parsed.differentiate("y").optimize();
Lepton::ParsedExpression gradz = parsed.differentiate("z").optimize();
if (!isZeroExpression(gradx))
expressions["/*"+cc.intToString(i+1)+"*/ force.x -= "] = gradx;
if (!isZeroExpression(grady))
expressions["/*"+cc.intToString(i+1)+"*/ force.y -= "] = grady;
if (!isZeroExpression(gradz))
expressions["/*"+cc.intToString(i+1)+"*/ force.z -= "] = gradz;
for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++)
expressions["/*"+cc.intToString(i+1)+"*/ energyParamDeriv"+cc.intToString(j)+" += "] = energyParamDerivExpressions[i][j];
}
for (int i = 1; i < numComputedValues; i++)
for (int j = 0; j < i; j++)
expressions["real dV"+cc.intToString(i)+"dV"+cc.intToString(j)+" = "] = valueDerivExpressions[i][j];
compute << cc.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "temp");
// Record values.
for (int i = 0; i < (int) energyDerivs->getParameterInfos().size(); i++) {
string index = cc.intToString(i+1);
compute << "derivBuffers" << index << "[index] = deriv" << index << ";\n";
}
compute << "forceBuffers[index] += realToFixedPoint(force.x);\n";
compute << "forceBuffers[index+PADDED_NUM_ATOMS] += realToFixedPoint(force.y);\n";
compute << "forceBuffers[index+PADDED_NUM_ATOMS*2] += realToFixedPoint(force.z);\n";
for (int i = 1; i < numComputedValues; i++) {
compute << "real totalDeriv"<getParameterInfos().size(); i++) {
string index = cc.intToString(i+1);
compute << "derivChain" << index << "[index] = deriv" << index << ";\n";
}
map replacements;
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
replacements["REDUCE_DERIVATIVES"] = reduce.str();
replacements["COMPUTE_ENERGY"] = compute.str();
replacements["INIT_PARAM_DERIVS"] = initParamDerivs.str();
replacements["SAVE_PARAM_DERIVS"] = saveParamDerivs.str();
map defines;
defines["NUM_ATOMS"] = cc.intToString(cc.getNumAtoms());
defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
ComputeProgram program = cc.compileProgram(cc.replaceStrings(CommonKernelSources::customGBEnergyPerParticle, replacements), defines);
perParticleEnergyKernel = program->createKernel("computePerParticleEnergy");
}
if (needParameterGradient || needEnergyParamDerivs) {
// Create the kernel to compute chain rule terms for computed values that depend explicitly on particle coordinates, and for
// derivatives with respect to global parameters.
stringstream compute, extraArgs, initParamDerivs, saveParamDerivs;
if (force.getNumGlobalParameters() > 0)
extraArgs << ", GLOBAL const real* globals";
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = params->getParameterInfos()[i];
string paramName = "params"+cc.intToString(i+1);
extraArgs << ", GLOBAL const " << buffer.getType() << "* RESTRICT " << paramName;
}
for (int i = 0; i < (int) computedValues->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = computedValues->getParameterInfos()[i];
string valueName = "values"+cc.intToString(i+1);
extraArgs << ", GLOBAL const " << buffer.getType() << "* RESTRICT " << valueName;
}
for (int i = 0; i < (int) energyDerivs->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = energyDerivs->getParameterInfos()[i];
string index = cc.intToString(i+1);
extraArgs << ", GLOBAL " << buffer.getType() << "* RESTRICT derivBuffers" << index;
compute << buffer.getType() << " deriv" << index << " = derivBuffers" << index << "[index];\n";
}
if (needEnergyParamDerivs) {
extraArgs << ", GLOBAL mixed* RESTRICT energyParamDerivs";
const vector& allParamDerivNames = cc.getEnergyParamDerivNames();
int numDerivs = allParamDerivNames.size();
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
for (int j = 0; j < dValuedParam[i]->getParameterInfos().size(); j++)
extraArgs << ", GLOBAL real* RESTRICT dValuedParam_" << j << "_" << i;
initParamDerivs << "mixed energyParamDeriv" << i << " = 0;\n";
for (int index = 0; index < numDerivs; index++)
if (allParamDerivNames[index] == force.getEnergyParameterDerivativeName(i))
saveParamDerivs << "energyParamDerivs[GLOBAL_ID*" << numDerivs << "+" << index << "] += energyParamDeriv" << i << ";\n";
}
}
map variables;
variables["x"] = "pos.x";
variables["y"] = "pos.y";
variables["z"] = "pos.z";
for (int i = 0; i < force.getNumPerParticleParameters(); i++)
variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]");
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
int index = cc.registerGlobalParam(force.getGlobalParameterName(i));
variables[force.getGlobalParameterName(i)] = "globals["+cc.intToString(index)+"]";
}
for (int i = 0; i < numComputedValues; i++)
variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]");
if (needParameterGradient) {
for (int i = 1; i < numComputedValues; i++) {
string is = cc.intToString(i);
compute << "real3 dV"< derivExpressions;
string js = cc.intToString(j);
derivExpressions["real dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j];
compute << cc.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, "temp_"+is+"_"+js);
compute << "dV"< gradientExpressions;
if (!isZeroExpression(valueGradientExpressions[i][0]))
gradientExpressions["dV"+is+"dR.x += "] = valueGradientExpressions[i][0];
if (!isZeroExpression(valueGradientExpressions[i][1]))
gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1];
if (!isZeroExpression(valueGradientExpressions[i][2]))
gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2];
compute << cc.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionList, functionDefinitions, "gradtemp_"+is);
}
for (int i = 1; i < numComputedValues; i++)
compute << "force -= deriv"<getParameterSuffix(i)<<"*dV"<getParameterSuffix(i)<<"*dValuedParam_"< replacements;
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
replacements["COMPUTE_FORCES"] = compute.str();
replacements["INIT_PARAM_DERIVS"] = initParamDerivs.str();
replacements["SAVE_PARAM_DERIVS"] = saveParamDerivs.str();
map defines;
defines["NUM_ATOMS"] = cc.intToString(cc.getNumAtoms());
defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
ComputeProgram program = cc.compileProgram(cc.replaceStrings(CommonKernelSources::customGBGradientChainRule, replacements), defines);
gradientChainRuleKernel = program->createKernel("computeGradientChainRuleTerms");
}
{
// Create the code to calculate chain rule terms as part of the default nonbonded kernel.
vector > globalVariables;
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i);
int index = cc.registerGlobalParam(force.getGlobalParameterName(i));
string value = "globals["+cc.intToString(index)+"]";
globalVariables.push_back(makeVariable(name, prefix+value));
}
vector > variables = globalVariables;
map rename;
ExpressionTreeNode rnode(new Operation::Variable("r"));
variables.push_back(make_pair(rnode, "r"));
variables.push_back(make_pair(ExpressionTreeNode(new Operation::Square(), rnode), "r2"));
variables.push_back(make_pair(ExpressionTreeNode(new Operation::Reciprocal(), rnode), "invR"));
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
const string& name = force.getPerParticleParameterName(i);
variables.push_back(makeVariable(name+"1", "((real) "+prefix+"params"+params->getParameterSuffix(i, "1)")));
variables.push_back(makeVariable(name+"2", "((real) "+prefix+"params"+params->getParameterSuffix(i, "2)")));
rename[name+"1"] = name+"2";
rename[name+"2"] = name+"1";
}
map derivExpressions;
stringstream chainSource;
Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize();
derivExpressions["real dV0dR1 = "] = dVdR;
derivExpressions["real dV0dR2 = "] = dVdR.renameVariables(rename);
chainSource << cc.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, prefix+"temp0_");
if (needChainForValue[0]) {
if (useExclusionsForValue)
chainSource << "if (!isExcluded) {\n";
chainSource << "tempForce -= dV0dR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "1") << ";\n";
chainSource << "tempForce -= dV0dR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "2") << ";\n";
if (useExclusionsForValue)
chainSource << "}\n";
}
for (int i = 1; i < numComputedValues; i++) {
if (needChainForValue[i]) {
chainSource << "tempForce -= dV0dR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "1") << ";\n";
chainSource << "tempForce -= dV0dR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "2") << ";\n";
}
}
map replacements;
string chainStr = chainSource.str();
replacements["COMPUTE_FORCE"] = chainStr;
string source = cc.replaceStrings(CommonKernelSources::customGBChainRule, replacements);
vector parameters;
vector arguments;
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = params->getParameterInfos()[i];
string paramName = prefix+"params"+cc.intToString(i+1);
if (chainStr.find(paramName+"1") != chainStr.npos || chainStr.find(paramName+"2") != chainStr.npos)
parameters.push_back(ComputeParameterInfo(buffer.getArray(), paramName, buffer.getComponentType(), buffer.getNumComponents()));
}
for (int i = 0; i < (int) computedValues->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = computedValues->getParameterInfos()[i];
string paramName = prefix+"values"+cc.intToString(i+1);
if (chainStr.find(paramName+"1") != chainStr.npos || chainStr.find(paramName+"2") != chainStr.npos)
parameters.push_back(ComputeParameterInfo(buffer.getArray(), paramName, buffer.getComponentType(), buffer.getNumComponents()));
}
for (int i = 0; i < (int) energyDerivChain->getParameterInfos().size(); i++) {
if (needChainForValue[i]) {
ComputeParameterInfo& buffer = energyDerivChain->getParameterInfos()[i];
string paramName = prefix+"dEdV"+cc.intToString(i+1);
parameters.push_back(ComputeParameterInfo(buffer.getArray(), paramName, buffer.getComponentType(), buffer.getNumComponents()));
}
}
if (needGlobalParams)
arguments.push_back(ComputeParameterInfo(cc.getGlobalParamValues(), prefix+"globals", "real", 1));
nb.addInteraction(useCutoff, usePeriodic, force.getNumExclusions() > 0, cutoff, exclusionList, source, force.getForceGroup());
for (auto param : parameters)
nb.addParameter(param);
for (auto arg : arguments)
nb.addArgument(arg);
}
info = new ForceInfo(force);
cc.addForce(info);
cc.addAutoclearBuffer(longEnergyDerivs);
}
/**
 * Evaluate the CustomGBForce by launching, in order: the pair value kernel, the per-particle value
 * kernel, the pair energy kernel, the per-particle energy kernel and, when parameter gradients or
 * energy parameter derivatives are needed, the gradient chain rule kernel.
 *
 * On the first call the two N^2 kernels are compiled and every kernel's argument list is built.
 * includeForces is not consulted here; includeEnergy is forwarded to the pair energy kernel.
 * Always returns 0.0. Energy contributions presumably go into the context's energy buffer,
 * which is passed to the kernels below.
 */
double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    ContextSelector selector(cc);
    NonbondedUtilities& nb = cc.getNonbondedUtilities();
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;

        // Adds the compile-time defines shared by both N^2 kernels: the number of tiles with
        // exclusions, the sub-range of those tiles assigned to this context (the work is split
        // evenly across cc.getNumContexts() contexts), and the cutoff distance.
        auto addExclusionTileDefines = [&](map<string, string>& defines) {
            int numExclusionTiles = nb.getExclusionTiles().getSize();
            int numContexts = cc.getNumContexts();
            int contextIndex = cc.getContextIndex();
            defines["NUM_TILES_WITH_EXCLUSIONS"] = cc.intToString(numExclusionTiles);
            defines["FIRST_EXCLUSION_TILE"] = cc.intToString(contextIndex*numExclusionTiles/numContexts);
            defines["LAST_EXCLUSION_TILE"] = cc.intToString((contextIndex+1)*numExclusionTiles/numContexts);
            defines["CUTOFF"] = cc.doubleToString(cutoff);
        };

        // These two kernels can't be compiled in initialize(), because the nonbonded utilities object
        // has not yet been initialized then. The saved source and defines are discarded once used.
        {
            addExclusionTileDefines(pairValueDefines);
            ComputeProgram program = cc.compileProgram(pairValueSrc, pairValueDefines);
            pairValueKernel = program->createKernel("computeN2Value");
            pairValueSrc = "";
            pairValueDefines.clear();
        }
        {
            addExclusionTileDefines(pairEnergyDefines);
            ComputeProgram program = cc.compileProgram(pairEnergySrc, pairEnergyDefines);
            pairEnergyKernel = program->createKernel("computeN2Energy");
            pairEnergySrc = "";
            pairEnergyDefines.clear();
        }

        // Set arguments for kernels. The order of addArg() calls must match each kernel's parameter
        // list, and the setArg() indices used on every call below depend on it.
        maxTiles = (nb.getUseCutoff() ? nb.getInteractingTiles().getSize() : 0);
        int numAtomBlocks = cc.getPaddedNumAtoms()/32;

        // Adds the arguments that tell an N^2 kernel which tiles to process. With a cutoff these are
        // the neighbor list arrays, five periodic box placeholders (filled in on every call) and the
        // neighbor list capacity. Without one, it is the number of tiles in the triangular block matrix.
        auto addTileArgs = [&](decltype(pairValueKernel)& kernel) {
            if (nb.getUseCutoff()) {
                kernel->addArg(nb.getInteractingTiles());
                kernel->addArg(nb.getInteractionCount());
                for (int i = 0; i < 5; i++)
                    kernel->addArg(); // Periodic box size arguments are set when the kernel is executed.
                kernel->addArg(maxTiles);
                kernel->addArg(nb.getBlockCenters());
                kernel->addArg(nb.getBlockBoundingBoxes());
                kernel->addArg(nb.getInteractingAtoms());
            }
            else
                kernel->addArg(numAtomBlocks*(numAtomBlocks+1)/2);
        };

        // Pair value kernel: arguments 0-3, then the tile arguments starting at index 4.
        pairValueKernel->addArg(cc.getPosq());
        pairValueKernel->addArg(nb.getExclusions());
        pairValueKernel->addArg(nb.getExclusionTiles());
        pairValueKernel->addArg(valueBuffers);
        addTileArgs(pairValueKernel);
        if (needGlobalParams)
            pairValueKernel->addArg(cc.getGlobalParamValues());
        for (int i = 0; i < (int) params->getParameterInfos().size(); i++)
            if (pairValueUsesParam[i])
                pairValueKernel->addArg(params->getParameterInfos()[i].getArray());
        for (auto& d : dValue0dParam)
            pairValueKernel->addArg(d);
        for (auto& function : tabulatedFunctionArrays)
            pairValueKernel->addArg(function);

        // Per-particle value kernel.
        perParticleValueKernel->addArg(cc.getPosq());
        perParticleValueKernel->addArg(valueBuffers);
        if (needGlobalParams)
            perParticleValueKernel->addArg(cc.getGlobalParamValues());
        for (auto& buffer : params->getParameterInfos())
            perParticleValueKernel->addArg(buffer.getArray());
        for (auto& buffer : computedValues->getParameterInfos())
            perParticleValueKernel->addArg(buffer.getArray());
        for (size_t i = 0; i < dValuedParam.size(); i++) {
            perParticleValueKernel->addArg(dValue0dParam[i]);
            for (auto& buffer : dValuedParam[i]->getParameterInfos())
                perParticleValueKernel->addArg(buffer.getArray());
        }
        for (auto& function : tabulatedFunctionArrays)
            perParticleValueKernel->addArg(function);

        // Pair energy kernel: arguments 0-5 (5 is the energy flag), then the tile arguments at index 6.
        pairEnergyKernel->addArg(cc.getLongForceBuffer());
        pairEnergyKernel->addArg(cc.getEnergyBuffer());
        pairEnergyKernel->addArg(cc.getPosq());
        pairEnergyKernel->addArg(nb.getExclusions());
        pairEnergyKernel->addArg(nb.getExclusionTiles());
        pairEnergyKernel->addArg(); // Whether to include energy.
        addTileArgs(pairEnergyKernel);
        if (needGlobalParams)
            pairEnergyKernel->addArg(cc.getGlobalParamValues());
        for (int i = 0; i < (int) params->getParameterInfos().size(); i++)
            if (pairEnergyUsesParam[i])
                pairEnergyKernel->addArg(params->getParameterInfos()[i].getArray());
        for (int i = 0; i < (int) computedValues->getParameterInfos().size(); i++)
            if (pairEnergyUsesValue[i])
                pairEnergyKernel->addArg(computedValues->getParameterInfos()[i].getArray());
        pairEnergyKernel->addArg(longEnergyDerivs);
        if (needEnergyParamDerivs)
            pairEnergyKernel->addArg(cc.getEnergyParamDerivBuffer());
        for (auto& function : tabulatedFunctionArrays)
            pairEnergyKernel->addArg(function);

        // Per-particle energy kernel.
        perParticleEnergyKernel->addArg(cc.getEnergyBuffer());
        perParticleEnergyKernel->addArg(cc.getPosq());
        perParticleEnergyKernel->addArg(cc.getLongForceBuffer());
        if (needGlobalParams)
            perParticleEnergyKernel->addArg(cc.getGlobalParamValues());
        for (auto& buffer : params->getParameterInfos())
            perParticleEnergyKernel->addArg(buffer.getArray());
        for (auto& buffer : computedValues->getParameterInfos())
            perParticleEnergyKernel->addArg(buffer.getArray());
        for (auto& buffer : energyDerivs->getParameterInfos())
            perParticleEnergyKernel->addArg(buffer.getArray());
        for (auto& buffer : energyDerivChain->getParameterInfos())
            perParticleEnergyKernel->addArg(buffer.getArray());
        perParticleEnergyKernel->addArg(longEnergyDerivs);
        if (needEnergyParamDerivs)
            perParticleEnergyKernel->addArg(cc.getEnergyParamDerivBuffer());
        for (auto& function : tabulatedFunctionArrays)
            perParticleEnergyKernel->addArg(function);

        // Gradient chain rule kernel; it only exists when one of these flags is set.
        if (needParameterGradient || needEnergyParamDerivs) {
            gradientChainRuleKernel->addArg(cc.getPosq());
            gradientChainRuleKernel->addArg(cc.getLongForceBuffer());
            if (needGlobalParams)
                gradientChainRuleKernel->addArg(cc.getGlobalParamValues());
            for (auto& buffer : params->getParameterInfos())
                gradientChainRuleKernel->addArg(buffer.getArray());
            for (auto& buffer : computedValues->getParameterInfos())
                gradientChainRuleKernel->addArg(buffer.getArray());
            for (auto& buffer : energyDerivs->getParameterInfos())
                gradientChainRuleKernel->addArg(buffer.getArray());
            if (needEnergyParamDerivs) {
                gradientChainRuleKernel->addArg(cc.getEnergyParamDerivBuffer());
                for (auto d : dValuedParam)
                    for (auto& buffer : d->getParameterInfos())
                        gradientChainRuleKernel->addArg(buffer.getArray());
            }
            for (auto& function : tabulatedFunctionArrays)
                gradientChainRuleKernel->addArg(function);
        }
    }

    // Per-call arguments. Indices follow the layout built above: pair value kernel box args at 6-10
    // and capacity at 11; pair energy kernel flag at 5, box args at 8-12 and capacity at 13.
    pairEnergyKernel->setArg(5, (int) includeEnergy);
    if (nb.getUseCutoff()) {
        setPeriodicBoxArgs(cc, pairValueKernel, 6);
        setPeriodicBoxArgs(cc, pairEnergyKernel, 8);
        if (maxTiles < nb.getInteractingTiles().getSize()) {
            // The neighbor list has grown since the capacity was last passed to the kernels.
            maxTiles = nb.getInteractingTiles().getSize();
            pairValueKernel->setArg(11, maxTiles);
            pairEnergyKernel->setArg(13, maxTiles);
        }
    }
    pairValueKernel->execute(nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize());
    perParticleValueKernel->execute(cc.getPaddedNumAtoms());
    pairEnergyKernel->execute(nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize());
    perParticleEnergyKernel->execute(cc.getPaddedNumAtoms());
    if (needParameterGradient || needEnergyParamDerivs)
        gradientChainRuleKernel->execute(cc.getPaddedNumAtoms());
    return 0.0;
}
/**
 * Push updated per-particle parameters and any modified tabulated functions from the force to the
 * device. The particle count must match the one the kernel was created with, otherwise an
 * OpenMMException is thrown. Afterwards the context is told that its molecule reordering may be
 * invalid.
 */
void CommonCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context, const CustomGBForce& force) {
    ContextSelector selector(cc);
    const int numParticles = force.getNumParticles();
    if (numParticles != cc.getNumAtoms())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");

    // Build a table with one row per padded atom. Rows past numParticles stay zero. Upload it.
    vector<vector<float> > paramVector(cc.getPaddedNumAtoms(), vector<float>(force.getNumPerParticleParameters(), 0));
    vector<double> particleParams;
    for (int atom = 0; atom < numParticles; atom++) {
        force.getParticleParameters(atom, particleParams);
        vector<float>& row = paramVector[atom];
        for (int p = 0; p < (int) particleParams.size(); p++)
            row[p] = static_cast<float>(particleParams[p]);
    }
    params->setParameterValues(paramVector);

    // Re-upload only the tabulated functions whose update count differs from the last one recorded.
    for (int fn = 0; fn < force.getNumTabulatedFunctions(); fn++) {
        const auto& function = force.getTabulatedFunction(fn);
        auto& lastUpdateCount = tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(fn)];
        if (function.getUpdateCount() == lastUpdateCount)
            continue;
        lastUpdateCount = function.getUpdateCount();
        int width;
        auto coefficients = cc.getExpressionUtilities().computeFunctionCoefficients(function, width);
        tabulatedFunctionArrays[fn].upload(coefficients);
    }

    // Mark that the current reordering may be invalid.
    cc.invalidateMolecules(info, true, false);
}