/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2011-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "OpenCLBondedUtilities.h"
#include "OpenCLContext.h"
#include "OpenCLExpressionUtilities.h"
#include "openmm/OpenMMException.h"
#include "OpenCLNonbondedUtilities.h"
#include
using namespace OpenMM;
using namespace std;
// Create a bonded-utilities helper bound to the given OpenCL context.
// It starts empty: no bonds recorded (maxBonds), no force groups in use
// (allGroups), and kernels not yet initialized (hasInitializedKernels).
OpenCLBondedUtilities::OpenCLBondedUtilities(OpenCLContext& context) :
        context(context),
        maxBonds(0),
        allGroups(0),
        hasInitializedKernels(false) {
}
// Register a bonded interaction. If `atoms` is non-empty, its per-bond atom
// lists, the kernel source fragment and the force group are appended to
// forceAtoms / forceSource / forceGroup, and (presumably) bit `group` is set
// in allGroups.
//
// NOTE(review): this span is damaged. The text after `1<` on the allGroups
// line has been lost, together with the end of addInteraction and the start
// of the following function(s). From that point the span continues with the
// tail of a routine that uploads atom indices and builds the
// computeBondedForces program. Restore this region from upstream before
// compiling.
void OpenCLBondedUtilities::addInteraction(const vector >& atoms, const string& source, int group) {
if (atoms.size() > 0) {
forceAtoms.push_back(atoms);
forceSource.push_back(source);
forceGroup.push_back(group);
// NOTE(review): truncated line. The statement presumably ends as a shift by
// `group`, and `indexVec(width*numBonds)` is a declaration from a later function.
allGroups |= 1< indexVec(width*numBonds);
// Pack each bond's atom indices into a flat array with a stride of `width`
// entries per bond. Slots past numAtoms within a bond are left untouched.
for (int bond = 0; bond < numBonds; bond++) {
for (int atom = 0; atom < numAtoms; atom++)
indexVec[bond*width+atom] = forceAtoms[i][bond][atom];
}
// Create the device buffer for force i and copy the packed indices to it.
atomIndices[i].initialize(context, indexVec.size(), "bondedIndices");
atomIndices[i].upload(indexVec);
}
// Create the kernel.
stringstream s;
// Emit any registered prefix code first.
// NOTE(review): the loop body and the following condition are truncated. The
// `s< 0)` line has lost its streamed text and the start of an `if (... > 0)`
// test, which presumably checks whether energy parameter derivatives were
// requested.
for (int i = 0; i < (int) prefixCode.size(); i++)
s< 0)
// Add the derivative output buffer as the final kernel parameter, then close
// the parameter list.
s<<", __global mixed* restrict energyParamDerivs";
s<<") {\n";
s<<"mixed energy = 0;\n";
// Presumably declares one `mixed energyParamDeriv<i>` accumulator per
// requested derivative.
// NOTE(review): the next line is truncated and fused with a later declaration
// of allParamDerivNames.
for (int i = 0; i < energyParameterDerivatives.size(); i++)
s<<"mixed energyParamDeriv"<& allParamDerivNames = context.getEnergyParamDerivNames();
int numDerivs = allParamDerivNames.size();
// For each requested derivative, find its index in the context-wide list of
// derivative names and emit a write into energyParamDerivs at a per-work-item
// offset based on get_global_id(0).
// NOTE(review): the emitted expression is truncated and fused with the
// declaration of `defines`.
for (int i = 0; i < energyParameterDerivatives.size(); i++)
for (int index = 0; index < numDerivs; index++)
if (allParamDerivNames[index] == energyParameterDerivatives[i])
s<<"energyParamDerivs[get_global_id(0)*"< defines;
// Compile the generated source with PADDED_NUM_ATOMS defined and get the
// computeBondedForces entry point.
defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
cl::Program program = context.createProgram(s.str(), defines);
kernel = cl::Kernel(program, "computeBondedForces");
// The host-side atom lists and source strings are discarded once the program
// is built. forceGroup is not cleared here.
forceAtoms.clear();
forceSource.clear();
}
// Build the OpenCL source fragment for one bonded force. The function records
// numBonds into maxBonds, then picks the vector type and component accessors
// used to load each bond's atom indices.
//
// NOTE(review): this span is damaged. At the `s<<"if ((groups&"` line the rest
// of createForceSource has been lost, together with the opening of the
// following function. The remainder of the span binds kernel arguments and
// launches the kernel. Restore this region from upstream before compiling.
string OpenCLBondedUtilities::createForceSource(int forceIndex, int numBonds, int numAtoms, int group, const string& computeForce) {
// maxBonds is the largest bond count seen across all forces. It is used as
// the launch size below.
maxBonds = max(maxBonds, numBonds);
// Round the atoms-per-bond count up to a power of two (minimum 1). This is the
// per-bond stride of the packed index array and the width of the index
// vector type.
int width = 1;
while (width < numAtoms)
width *= 2;
// Component accessors for each vector width: none for a scalar uint,
// .x-.w for 2- and 4-wide vectors, .s0-.s15 for 8- and 16-wide vectors.
// NOTE(review): numAtoms > 16 gives width 32. Then indexType becomes "uint32",
// which is not an OpenCL built-in vector type, and suffix16 has only 16
// entries. Confirm that the missing code rejects this case.
string suffix1[] = {""};
string suffix4[] = {".x", ".y", ".z", ".w"};
string suffix16[] = {".s0", ".s1", ".s2", ".s3", ".s4", ".s5", ".s6", ".s7",
".s8", ".s9", ".s10", ".s11", ".s12", ".s13", ".s14", ".s15"};
string* suffix;
if (width == 1)
suffix = suffix1;
else if (width <= 4)
suffix = suffix4;
else
suffix = suffix16;
// "uint" for width 1, otherwise "uint<width>" (uint2, uint4, uint8, uint16).
string indexType = "uint"+(width == 1 ? "" : context.intToString(width));
stringstream s;
// NOTE(review): truncated line. The generated group test is cut off after
// `(1<`, and the text that follows belongs to the kernel-launch routine
// (the start of a kernel.setArg call).
s<<"if ((groups&"<<(1<(index++, context.getLongForceBuffer().getDeviceBuffer());
kernel.setArg(index++, context.getEnergyBuffer().getDeviceBuffer());
kernel.setArg(index++, context.getPosq().getDeviceBuffer());
// Skip the six slots that are refreshed on every call below (args 3-8).
// This assumes the three buffers above occupy slots 0-2.
index += 6;
// Then bind the per-force atom index buffers, the extra arguments, and
// (if any derivatives were requested) the energy-parameter-derivative buffer.
for (int j = 0; j < (int) atomIndices.size(); j++)
kernel.setArg(index++, atomIndices[j].getDeviceBuffer());
for (int j = 0; j < (int) arguments.size(); j++)
kernel.setArg(index++, *arguments[j]);
if (energyParameterDerivatives.size() > 0)
kernel.setArg(index++, context.getEnergyParamDerivBuffer().getDeviceBuffer());
}
// Per-call arguments: the force-group mask (arg 3) and the periodic box
// (args 4-8), passed in double or single precision to match the context.
kernel.setArg(3, groups);
if (context.getUseDoublePrecision()) {
kernel.setArg(4, context.getPeriodicBoxSizeDouble());
kernel.setArg(5, context.getInvPeriodicBoxSizeDouble());
kernel.setArg(6, context.getPeriodicBoxVecXDouble());
kernel.setArg(7, context.getPeriodicBoxVecYDouble());
kernel.setArg(8, context.getPeriodicBoxVecZDouble());
}
else {
kernel.setArg(4, context.getPeriodicBoxSize());
kernel.setArg(5, context.getInvPeriodicBoxSize());
kernel.setArg(6, context.getPeriodicBoxVecX());
kernel.setArg(7, context.getPeriodicBoxVecY());
kernel.setArg(8, context.getPeriodicBoxVecZ());
}
// Launch sized by maxBonds, the largest bond count of any registered force
// (presumably the global work size).
context.executeKernel(kernel, maxBonds);
}