Unverified Commit 5739788a authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Merged more code into common platform (#4346)

* Common implementation of BondedUtilities

* Common implementation of UpdateStateDataKernel
parent 796ffaaa
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2011-2019 Stanford University and the Authors. * * Portions copyright (c) 2011-2023 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -27,19 +27,17 @@ ...@@ -27,19 +27,17 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>. * * along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "openmm/common/ArrayInterface.h" #include "openmm/common/ComputeArray.h"
#include "openmm/common/ComputeKernel.h"
#include "openmm/System.h"
#include <string> #include <string>
#include <vector> #include <vector>
namespace OpenMM { namespace OpenMM {
/** /**
* This abstract class defines an interface for computing bonded interactions. Call
* getBondedUtilities() on a ComputeContext to get the BondedUtilities object for that
* context.
*
* This class provides a generic mechanism for evaluating bonded interactions. You write only * This class provides a generic mechanism for evaluating bonded interactions. You write only
* the source code needed to compute one interaction, and this object takes care of creating * the source code needed to compute one interaction, and this class takes care of creating
* and executing a complete kernel that loops over bonds, evaluates each one, and accumulates * and executing a complete kernel that loops over bonds, evaluates each one, and accumulates
* the resulting forces and energies. This offers two advantages. First, it simplifies the * the resulting forces and energies. This offers two advantages. First, it simplifies the
* task of writing a new Force. Second, it allows multiple forces to be evaluated by a single * task of writing a new Force. Second, it allows multiple forces to be evaluated by a single
...@@ -85,6 +83,7 @@ namespace OpenMM { ...@@ -85,6 +83,7 @@ namespace OpenMM {
class OPENMM_EXPORT_COMMON BondedUtilities { class OPENMM_EXPORT_COMMON BondedUtilities {
public: public:
BondedUtilities(ComputeContext& context);
virtual ~BondedUtilities() { virtual ~BondedUtilities() {
} }
/** /**
...@@ -95,7 +94,7 @@ public: ...@@ -95,7 +94,7 @@ public:
* @param source the code to evaluate the interaction * @param source the code to evaluate the interaction
* @param group the force group in which the interaction should be calculated * @param group the force group in which the interaction should be calculated
*/ */
virtual void addInteraction(const std::vector<std::vector<int> >& atoms, const std::string& source, int group) = 0; void addInteraction(const std::vector<std::vector<int> >& atoms, const std::string& source, int group);
/** /**
* Add an argument that should be passed to the interaction kernel. * Add an argument that should be passed to the interaction kernel.
* *
...@@ -104,7 +103,7 @@ public: ...@@ -104,7 +103,7 @@ public:
* @return the name that will be used for the argument. Any code you pass to addInteraction() should * @return the name that will be used for the argument. Any code you pass to addInteraction() should
* refer to it by this name. * refer to it by this name.
*/ */
virtual std::string addArgument(ArrayInterface& data, const std::string& type) = 0; std::string addArgument(ArrayInterface& data, const std::string& type);
/** /**
* Register that the interaction kernel will be computing the derivative of the potential energy * Register that the interaction kernel will be computing the derivative of the potential energy
* with respect to a parameter. * with respect to a parameter.
...@@ -113,14 +112,39 @@ public: ...@@ -113,14 +112,39 @@ public:
* @return the variable that will be used to accumulate the derivative. Any code you pass to addInteraction() should * @return the variable that will be used to accumulate the derivative. Any code you pass to addInteraction() should
* add its contributions to this variable. * add its contributions to this variable.
*/ */
virtual std::string addEnergyParameterDerivative(const std::string& param) = 0; std::string addEnergyParameterDerivative(const std::string& param);
/** /**
* Add some code that should be included in the program, before the start of the kernel. * Add some code that should be included in the program, before the start of the kernel.
* This can be used, for example, to define functions that will be called by the kernel. * This can be used, for example, to define functions that will be called by the kernel.
* *
* @param source the code to include * @param source the code to include
*/ */
virtual void addPrefixCode(const std::string& source) = 0; void addPrefixCode(const std::string& source);
/**
* Initialize this object in preparation for a simulation.
*/
void initialize(const System& system);
/**
* Compute the bonded interactions.
*
* @param groups a set of bit flags for which force groups to include
*/
void computeInteractions(int groups);
private:
std::string createForceSource(int forceIndex, int numBonds, int numAtoms, int group, const std::string& computeForce);
ComputeContext& context;
ComputeKernel kernel;
std::vector<std::vector<std::vector<int> > > forceAtoms;
std::vector<std::vector<int> > indexWidth;
std::vector<std::string> forceSource;
std::vector<int> forceGroup;
std::vector<ArrayInterface*> arguments;
std::vector<std::string> argTypes;
std::vector<std::vector<ComputeArray> > atomIndices;
std::vector<std::string> prefixCode;
std::vector<std::string> energyParameterDerivatives;
int numForceBuffers, maxBonds, allGroups;
bool hasInitializedKernels, hasInteractions;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -41,6 +41,121 @@ ...@@ -41,6 +41,121 @@
namespace OpenMM { namespace OpenMM {
/**
* This kernel provides methods for setting and retrieving various state data: time, positions,
* velocities, and forces.
*/
class CommonUpdateStateDataKernel : public UpdateStateDataKernel {
public:
CommonUpdateStateDataKernel(std::string name, const Platform& platform, ComputeContext& cc) : UpdateStateDataKernel(name, platform), cc(cc) {
}
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
*/
void initialize(const System& system);
/**
* Get the current time (in picoseconds).
*
* @param context the context in which to execute this kernel
*/
double getTime(const ContextImpl& context) const;
/**
* Set the current time (in picoseconds).
*
* @param context the context in which to execute this kernel
*/
void setTime(ContextImpl& context, double time);
/**
* Get the current step count
*
* @param context the context in which to execute this kernel
*/
long long getStepCount(const ContextImpl& context) const;
/**
* Set the current step count
*
* @param context the context in which to execute this kernel
*/
void setStepCount(const ContextImpl& context, long long count);
/**
* Get the positions of all particles.
*
* @param positions on exit, this contains the particle positions
*/
void getPositions(ContextImpl& context, std::vector<Vec3>& positions);
/**
* Set the positions of all particles.
*
* @param positions a vector containg the particle positions
*/
void setPositions(ContextImpl& context, const std::vector<Vec3>& positions);
/**
* Get the velocities of all particles.
*
* @param velocities on exit, this contains the particle velocities
*/
void getVelocities(ContextImpl& context, std::vector<Vec3>& velocities);
/**
* Set the velocities of all particles.
*
* @param velocities a vector containg the particle velocities
*/
void setVelocities(ContextImpl& context, const std::vector<Vec3>& velocities);
/**
* Compute velocities, shifted in time to account for a leapfrog integrator. The shift
* is based on the most recently computed forces.
*
* @param context the context in which to execute this kernel
* @param timeShift the amount by which to shift the velocities in time
* @param velocities the shifted velocities are returned in this
*/
void computeShiftedVelocities(ContextImpl& context, double timeShift, std::vector<Vec3>& velocities);
/**
* Get the current forces on all particles.
*
* @param forces on exit, this contains the forces
*/
void getForces(ContextImpl& context, std::vector<Vec3>& forces);
/**
* Get the current derivatives of the energy with respect to context parameters.
*
* @param derivs on exit, this contains the derivatives
*/
void getEnergyParameterDerivatives(ContextImpl& context, std::map<std::string, double>& derivs);
/**
* Get the current periodic box vectors.
*
* @param a on exit, this contains the vector defining the first edge of the periodic box
* @param b on exit, this contains the vector defining the second edge of the periodic box
* @param c on exit, this contains the vector defining the third edge of the periodic box
*/
void getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const;
/**
* Set the current periodic box vectors.
*
* @param a the vector defining the first edge of the periodic box
* @param b the vector defining the second edge of the periodic box
* @param c the vector defining the third edge of the periodic box
*/
void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c);
/**
* Create a checkpoint recording the current state of the Context.
*
* @param stream an output stream the checkpoint data should be written to
*/
void createCheckpoint(ContextImpl& context, std::ostream& stream);
/**
* Load a checkpoint that was written by createCheckpoint().
*
* @param stream an input stream the checkpoint data should be read from
*/
void loadCheckpoint(ContextImpl& context, std::istream& stream);
private:
ComputeContext& cc;
};
/** /**
* This kernel modifies the positions of particles to enforce distance constraints. * This kernel modifies the positions of particles to enforce distance constraints.
*/ */
......
...@@ -129,6 +129,12 @@ public: ...@@ -129,6 +129,12 @@ public:
* one ComputeContext is created for each device. * one ComputeContext is created for each device.
*/ */
virtual int getContextIndex() const = 0; virtual int getContextIndex() const = 0;
/**
* Get a list of all contexts being used for the current simulation.
* This is relevant when a simulation is parallelized across multiple devices. In that case,
* one ComputeContext is created for each device.
*/
virtual std::vector<ComputeContext*> getAllContexts() = 0;
/** /**
* Construct an uninitialized array of the appropriate class for this platform. The returned * Construct an uninitialized array of the appropriate class for this platform. The returned
* value should be created on the heap with the "new" operator. * value should be created on the heap with the "new" operator.
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2011-2019 Stanford University and the Authors. * * Portions copyright (c) 2011-2023 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -24,21 +24,18 @@ ...@@ -24,21 +24,18 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>. * * along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "CudaBondedUtilities.h" #include "openmm/common/BondedUtilities.h"
#include "CudaContext.h" #include "openmm/common/ComputeContext.h"
#include "CudaExpressionUtilities.h"
#include "CudaKernelSources.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "CudaNonbondedUtilities.h"
#include <iostream> #include <iostream>
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
CudaBondedUtilities::CudaBondedUtilities(CudaContext& context) : context(context), numForceBuffers(0), maxBonds(0), allGroups(0), hasInitializedKernels(false) { BondedUtilities::BondedUtilities(ComputeContext& context) : context(context), numForceBuffers(0), maxBonds(0), allGroups(0), hasInitializedKernels(false) {
} }
void CudaBondedUtilities::addInteraction(const vector<vector<int> >& atoms, const string& source, int group) { void BondedUtilities::addInteraction(const vector<vector<int> >& atoms, const string& source, int group) {
if (atoms.size() > 0) { if (atoms.size() > 0) {
forceAtoms.push_back(atoms); forceAtoms.push_back(atoms);
forceSource.push_back(source); forceSource.push_back(source);
...@@ -47,17 +44,13 @@ void CudaBondedUtilities::addInteraction(const vector<vector<int> >& atoms, cons ...@@ -47,17 +44,13 @@ void CudaBondedUtilities::addInteraction(const vector<vector<int> >& atoms, cons
} }
} }
string CudaBondedUtilities::addArgument(CUdeviceptr data, const string& type) { string BondedUtilities::addArgument(ArrayInterface& data, const string& type) {
arguments.push_back(data); arguments.push_back(&data);
argTypes.push_back(type); argTypes.push_back(type);
return "customArg"+context.intToString(arguments.size()); return "customArg"+context.intToString(arguments.size());
} }
string CudaBondedUtilities::addArgument(ArrayInterface& data, const string& type) { string BondedUtilities::addEnergyParameterDerivative(const string& param) {
return addArgument(context.unwrap(data).getDevicePointer(), type);
}
string CudaBondedUtilities::addEnergyParameterDerivative(const string& param) {
// See if the parameter has already been added. // See if the parameter has already been added.
int index; int index;
...@@ -70,14 +63,14 @@ string CudaBondedUtilities::addEnergyParameterDerivative(const string& param) { ...@@ -70,14 +63,14 @@ string CudaBondedUtilities::addEnergyParameterDerivative(const string& param) {
return string("energyParamDeriv")+context.intToString(index); return string("energyParamDeriv")+context.intToString(index);
} }
void CudaBondedUtilities::addPrefixCode(const string& source) { void BondedUtilities::addPrefixCode(const string& source) {
for (int i = 0; i < (int) prefixCode.size(); i++) for (int i = 0; i < (int) prefixCode.size(); i++)
if (prefixCode[i] == source) if (prefixCode[i] == source)
return; return;
prefixCode.push_back(source); prefixCode.push_back(source);
} }
void CudaBondedUtilities::initialize(const System& system) { void BondedUtilities::initialize(const System& system) {
int numForces = forceAtoms.size(); int numForces = forceAtoms.size();
hasInteractions = (numForces > 0); hasInteractions = (numForces > 0);
if (!hasInteractions) if (!hasInteractions)
...@@ -109,53 +102,54 @@ void CudaBondedUtilities::initialize(const System& system) { ...@@ -109,53 +102,54 @@ void CudaBondedUtilities::initialize(const System& system) {
// Create the kernel. // Create the kernel.
stringstream s; stringstream s;
s<<CudaKernelSources::vectorOps;
for (int i = 0; i < (int) prefixCode.size(); i++) for (int i = 0; i < (int) prefixCode.size(); i++)
s<<prefixCode[i]; s<<prefixCode[i];
s<<"extern \"C\" __global__ void computeBondedForces(unsigned long long* __restrict__ forceBuffer, mixed* __restrict__ energyBuffer, const real4* __restrict__ posq, int groups, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ"; s<<"KERNEL void computeBondedForces(GLOBAL mm_ulong* RESTRICT forceBuffer, GLOBAL mixed* RESTRICT energyBuffer, GLOBAL const real4* RESTRICT posq, int groups, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ";
for (int force = 0; force < numForces; force++) { for (int force = 0; force < numForces; force++) {
for (int i = 0; i < (int) atomIndices[force].size(); i++) { for (int i = 0; i < (int) atomIndices[force].size(); i++) {
int indexWidth = atomIndices[force][i].getElementSize()/4; int indexWidth = atomIndices[force][i].getElementSize()/4;
string indexType = "uint"+context.intToString(indexWidth); string indexType = (indexWidth == 1 ? "unsigned int" : "uint"+context.intToString(indexWidth));
s<<", const "<<indexType<<"* __restrict__ atomIndices"<<force<<"_"<<i; s<<", GLOBAL const "<<indexType<<"* RESTRICT atomIndices"<<force<<"_"<<i;
} }
} }
for (int i = 0; i < (int) arguments.size(); i++) for (int i = 0; i < (int) arguments.size(); i++)
s<<", "<<argTypes[i]<<"* customArg"<<(i+1); s<<", GLOBAL "<<argTypes[i]<<"* customArg"<<(i+1);
if (energyParameterDerivatives.size() > 0) if (energyParameterDerivatives.size() > 0)
s<<", mixed* __restrict__ energyParamDerivs"; s<<", GLOBAL mixed* RESTRICT energyParamDerivs";
s<<") {\n"; s<<") {\n";
s<<"mixed energy = 0;\n"; s<<"mixed energy = 0;\n";
for (int i = 0; i < energyParameterDerivatives.size(); i++) for (int i = 0; i < energyParameterDerivatives.size(); i++)
s<<"mixed energyParamDeriv"<<i<<" = 0;\n"; s<<"mixed energyParamDeriv"<<i<<" = 0;\n";
for (int force = 0; force < numForces; force++) for (int force = 0; force < numForces; force++)
s<<createForceSource(force, forceAtoms[force].size(), forceAtoms[force][0].size(), forceGroup[force], forceSource[force]); s<<createForceSource(force, forceAtoms[force].size(), forceAtoms[force][0].size(), forceGroup[force], forceSource[force]);
s<<"energyBuffer[blockIdx.x*blockDim.x+threadIdx.x] += energy;\n"; s<<"energyBuffer[GLOBAL_ID] += energy;\n";
const vector<string>& allParamDerivNames = context.getEnergyParamDerivNames(); const vector<string>& allParamDerivNames = context.getEnergyParamDerivNames();
int numDerivs = allParamDerivNames.size(); int numDerivs = allParamDerivNames.size();
for (int i = 0; i < energyParameterDerivatives.size(); i++) for (int i = 0; i < energyParameterDerivatives.size(); i++)
for (int index = 0; index < numDerivs; index++) for (int index = 0; index < numDerivs; index++)
if (allParamDerivNames[index] == energyParameterDerivatives[i]) if (allParamDerivNames[index] == energyParameterDerivatives[i])
s<<"energyParamDerivs[(blockIdx.x*blockDim.x+threadIdx.x)*"<<numDerivs<<"+"<<index<<"] += energyParamDeriv"<<i<<";\n"; s<<"energyParamDerivs[(GLOBAL_ID)*"<<numDerivs<<"+"<<index<<"] += energyParamDeriv"<<i<<";\n";
s<<"}\n"; s<<"}\n";
map<string, string> defines; map<string, string> defines;
defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms()); defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
CUmodule module = context.createModule(s.str(), defines); ComputeProgram program = context.compileProgram(s.str(), defines);
kernel = context.getKernel(module, "computeBondedForces"); kernel = program->createKernel("computeBondedForces");
forceAtoms.clear(); forceAtoms.clear();
forceSource.clear(); forceSource.clear();
} }
string CudaBondedUtilities::createForceSource(int forceIndex, int numBonds, int numAtoms, int group, const string& computeForce) { string BondedUtilities::createForceSource(int forceIndex, int numBonds, int numAtoms, int group, const string& computeForce) {
maxBonds = max(maxBonds, numBonds); maxBonds = max(maxBonds, numBonds);
string suffix[] = {".x", ".y", ".z", ".w"}; string suffix1[] = {""};
string suffix4[] = {".x", ".y", ".z", ".w"};
stringstream s; stringstream s;
s<<"if ((groups&"<<(1<<group)<<") != 0)\n"; s<<"if ((groups&"<<(1<<group)<<") != 0)\n";
s<<"for (unsigned int index = blockIdx.x*blockDim.x+threadIdx.x; index < "<<numBonds<<"; index += blockDim.x*gridDim.x) {\n"; s<<"for (unsigned int index = GLOBAL_ID; index < "<<numBonds<<"; index += GLOBAL_SIZE) {\n";
int startAtom = 0; int startAtom = 0;
for (int i = 0; i < (int) atomIndices[forceIndex].size(); i++) { for (int i = 0; i < (int) atomIndices[forceIndex].size(); i++) {
int indexWidth = atomIndices[forceIndex][i].getElementSize()/4; int indexWidth = atomIndices[forceIndex][i].getElementSize()/4;
string indexType = "uint"+context.intToString(indexWidth); string* suffix = (indexWidth == 1 ? suffix1 : suffix4);
string indexType = (indexWidth == 1 ? "unsigned int" : "uint"+context.intToString(indexWidth));
s<<" "<<indexType<<" atoms"<<i<<" = atomIndices"<<forceIndex<<"_"<<i<<"[index];\n"; s<<" "<<indexType<<" atoms"<<i<<" = atomIndices"<<forceIndex<<"_"<<i<<"[index];\n";
int atomsToLoad = min(indexWidth, numAtoms-startAtom); int atomsToLoad = min(indexWidth, numAtoms-startAtom);
for (int j = 0; j < atomsToLoad; j++) { for (int j = 0; j < atomsToLoad; j++) {
...@@ -166,39 +160,51 @@ string CudaBondedUtilities::createForceSource(int forceIndex, int numBonds, int ...@@ -166,39 +160,51 @@ string CudaBondedUtilities::createForceSource(int forceIndex, int numBonds, int
} }
s<<computeForce<<"\n"; s<<computeForce<<"\n";
for (int i = 0; i < numAtoms; i++) { for (int i = 0; i < numAtoms; i++) {
s<<" atomicAdd(&forceBuffer[atom"<<(i+1)<<"], static_cast<unsigned long long>(realToFixedPoint(force"<<(i+1)<<".x)));\n"; s<<" ATOMIC_ADD(&forceBuffer[atom"<<(i+1)<<"], (mm_ulong) realToFixedPoint(force"<<(i+1)<<".x));\n";
s<<" atomicAdd(&forceBuffer[atom"<<(i+1)<<"+PADDED_NUM_ATOMS], static_cast<unsigned long long>(realToFixedPoint(force"<<(i+1)<<".y)));\n"; s<<" ATOMIC_ADD(&forceBuffer[atom"<<(i+1)<<"+PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(force"<<(i+1)<<".y));\n";
s<<" atomicAdd(&forceBuffer[atom"<<(i+1)<<"+PADDED_NUM_ATOMS*2], static_cast<unsigned long long>(realToFixedPoint(force"<<(i+1)<<".z)));\n"; s<<" ATOMIC_ADD(&forceBuffer[atom"<<(i+1)<<"+PADDED_NUM_ATOMS*2], (mm_ulong) realToFixedPoint(force"<<(i+1)<<".z));\n";
s<<" __threadfence_block();\n"; s<<" MEM_FENCE;\n";
} }
s<<"}\n"; s<<"}\n";
return s.str(); return s.str();
} }
void CudaBondedUtilities::computeInteractions(int groups) { void BondedUtilities::computeInteractions(int groups) {
if ((groups&allGroups) == 0) if ((groups&allGroups) == 0)
return; return;
if (!hasInitializedKernels) { if (!hasInitializedKernels) {
hasInitializedKernels = true; hasInitializedKernels = true;
kernelArgs.push_back(&context.getForce().getDevicePointer()); kernel->addArg(context.getLongForceBuffer());
kernelArgs.push_back(&context.getEnergyBuffer().getDevicePointer()); kernel->addArg(context.getEnergyBuffer());
kernelArgs.push_back(&context.getPosq().getDevicePointer()); kernel->addArg(context.getPosq());
kernelArgs.push_back(NULL); for (int i = 0; i < 6; i++)
kernelArgs.push_back(context.getPeriodicBoxSizePointer()); kernel->addArg();
kernelArgs.push_back(context.getInvPeriodicBoxSizePointer());
kernelArgs.push_back(context.getPeriodicBoxVecXPointer());
kernelArgs.push_back(context.getPeriodicBoxVecYPointer());
kernelArgs.push_back(context.getPeriodicBoxVecZPointer());
for (int i = 0; i < (int) atomIndices.size(); i++) for (int i = 0; i < (int) atomIndices.size(); i++)
for (int j = 0; j < (int) atomIndices[i].size(); j++) for (int j = 0; j < (int) atomIndices[i].size(); j++)
kernelArgs.push_back(&atomIndices[i][j].getDevicePointer()); kernel->addArg(atomIndices[i][j]);
for (int i = 0; i < (int) arguments.size(); i++) for (int i = 0; i < (int) arguments.size(); i++)
kernelArgs.push_back(&arguments[i]); kernel->addArg(*arguments[i]);
if (energyParameterDerivatives.size() > 0) if (energyParameterDerivatives.size() > 0)
kernelArgs.push_back(&context.getEnergyParamDerivBuffer().getDevicePointer()); kernel->addArg(context.getEnergyParamDerivBuffer());
} }
if (!hasInteractions) if (!hasInteractions)
return; return;
kernelArgs[3] = &groups; kernel->setArg(3, groups);
context.executeKernel(kernel, &kernelArgs[0], maxBonds); Vec3 a, b, c;
context.getPeriodicBoxVectors(a, b, c);
if (context.getUseDoublePrecision()) {
kernel->setArg(4, mm_double4(a[0], b[1], c[2], 0.0));
kernel->setArg(5, mm_double4(1.0/a[0], 1.0/b[1], 1.0/c[2], 0.0));
kernel->setArg(6, mm_double4(a[0], a[1], a[2], 0.0));
kernel->setArg(7, mm_double4(b[0], b[1], b[2], 0.0));
kernel->setArg(8, mm_double4(c[0], c[1], c[2], 0.0));
}
else {
kernel->setArg(4, mm_float4((float) a[0], (float) b[1], (float) c[2], 0.0f));
kernel->setArg(5, mm_float4(1.0f/(float) a[0], 1.0f/(float) b[1], 1.0f/(float) c[2], 0.0f));
kernel->setArg(6, mm_float4((float) a[0], (float) a[1], (float) a[2], 0.0f));
kernel->setArg(7, mm_float4((float) b[0], (float) b[1], (float) b[2], 0.0f));
kernel->setArg(8, mm_float4((float) c[0], (float) c[1], (float) c[2], 0.0f));
}
kernel->execute(maxBonds);
} }
...@@ -112,6 +112,341 @@ static void flushPeriodically(ComputeContext& cc) { ...@@ -112,6 +112,341 @@ static void flushPeriodically(ComputeContext& cc) {
#endif #endif
} }
void CommonUpdateStateDataKernel::initialize(const System& system) {
}
double CommonUpdateStateDataKernel::getTime(const ContextImpl& context) const {
return cc.getTime();
}
void CommonUpdateStateDataKernel::setTime(ContextImpl& context, double time) {
for (auto ctx : cc.getAllContexts())
ctx->setTime(time);
}
long long CommonUpdateStateDataKernel::getStepCount(const ContextImpl& context) const {
return cc.getStepCount();
}
void CommonUpdateStateDataKernel::setStepCount(const ContextImpl& context, long long count) {
for (auto ctx : cc.getAllContexts())
ctx->setStepCount(count);
}
void CommonUpdateStateDataKernel::getPositions(ContextImpl& context, vector<Vec3>& positions) {
ContextSelector selector(cc);
int numParticles = context.getSystem().getNumParticles();
positions.resize(numParticles);
vector<mm_float4> posCorrection;
if (cc.getUseDoublePrecision()) {
mm_double4* posq = (mm_double4*) cc.getPinnedBuffer();
cc.getPosq().download(posq);
}
else if (cc.getUseMixedPrecision()) {
mm_float4* posq = (mm_float4*) cc.getPinnedBuffer();
cc.getPosq().download(posq, false);
posCorrection.resize(numParticles);
cc.getPosqCorrection().download(posCorrection);
}
else {
mm_float4* posq = (mm_float4*) cc.getPinnedBuffer();
cc.getPosq().download(posq);
}
// Filling in the output array is done in parallel for speed.
cc.getThreadPool().execute([&] (ThreadPool& threads, int threadIndex) {
// Compute the position of each particle to return to the user. This is done in parallel for speed.
const vector<int>& order = cc.getAtomIndex();
int numParticles = cc.getNumAtoms();
Vec3 boxVectors[3];
cc.getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
int numThreads = threads.getNumThreads();
int start = threadIndex*numParticles/numThreads;
int end = (threadIndex+1)*numParticles/numThreads;
if (cc.getUseDoublePrecision()) {
mm_double4* posq = (mm_double4*) cc.getPinnedBuffer();
for (int i = start; i < end; ++i) {
mm_double4 pos = posq[i];
mm_int4 offset = cc.getPosCellOffsets()[i];
positions[order[i]] = Vec3(pos.x, pos.y, pos.z)-boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z;
}
}
else if (cc.getUseMixedPrecision()) {
mm_float4* posq = (mm_float4*) cc.getPinnedBuffer();
for (int i = start; i < end; ++i) {
mm_float4 pos1 = posq[i];
mm_float4 pos2 = posCorrection[i];
mm_int4 offset = cc.getPosCellOffsets()[i];
positions[order[i]] = Vec3((double)pos1.x+(double)pos2.x, (double)pos1.y+(double)pos2.y, (double)pos1.z+(double)pos2.z)-boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z;
}
}
else {
mm_float4* posq = (mm_float4*) cc.getPinnedBuffer();
for (int i = start; i < end; ++i) {
mm_float4 pos = posq[i];
mm_int4 offset = cc.getPosCellOffsets()[i];
positions[order[i]] = Vec3(pos.x, pos.y, pos.z)-boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z;
}
}
});
cc.getThreadPool().waitForThreads();
}
void CommonUpdateStateDataKernel::setPositions(ContextImpl& context, const vector<Vec3>& positions) {
ContextSelector selector(cc);
const vector<int>& order = cc.getAtomIndex();
int numParticles = context.getSystem().getNumParticles();
if (cc.getUseDoublePrecision()) {
mm_double4* posq = (mm_double4*) cc.getPinnedBuffer();
cc.getPosq().download(posq);
for (int i = 0; i < numParticles; ++i) {
mm_double4& pos = posq[i];
const Vec3& p = positions[order[i]];
pos.x = p[0];
pos.y = p[1];
pos.z = p[2];
}
for (int i = numParticles; i < cc.getPaddedNumAtoms(); i++)
posq[i] = mm_double4(0.0, 0.0, 0.0, 0.0);
cc.getPosq().upload(posq);
}
else {
mm_float4* posq = (mm_float4*) cc.getPinnedBuffer();
cc.getPosq().download(posq);
for (int i = 0; i < numParticles; ++i) {
mm_float4& pos = posq[i];
const Vec3& p = positions[order[i]];
pos.x = (float) p[0];
pos.y = (float) p[1];
pos.z = (float) p[2];
}
for (int i = numParticles; i < cc.getPaddedNumAtoms(); i++)
posq[i] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f);
cc.getPosq().upload(posq);
}
if (cc.getUseMixedPrecision()) {
mm_float4* posCorrection = (mm_float4*) cc.getPinnedBuffer();
for (int i = 0; i < numParticles; ++i) {
mm_float4& c = posCorrection[i];
const Vec3& p = positions[order[i]];
c.x = (float) (p[0]-(float)p[0]);
c.y = (float) (p[1]-(float)p[1]);
c.z = (float) (p[2]-(float)p[2]);
c.w = 0;
}
for (int i = numParticles; i < cc.getPaddedNumAtoms(); i++)
posCorrection[i] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f);
cc.getPosqCorrection().upload(posCorrection);
}
for (auto& offset : cc.getPosCellOffsets())
offset = mm_int4(0, 0, 0, 0);
cc.reorderAtoms();
}
void CommonUpdateStateDataKernel::getVelocities(ContextImpl& context, vector<Vec3>& velocities) {
ContextSelector selector(cc);
const vector<int>& order = cc.getAtomIndex();
int numParticles = context.getSystem().getNumParticles();
velocities.resize(numParticles);
if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
mm_double4* velm = (mm_double4*) cc.getPinnedBuffer();
cc.getVelm().download(velm);
for (int i = 0; i < numParticles; ++i) {
mm_double4 vel = velm[i];
velocities[order[i]] = Vec3(vel.x, vel.y, vel.z);
}
}
else {
mm_float4* velm = (mm_float4*) cc.getPinnedBuffer();
cc.getVelm().download(velm);
for (int i = 0; i < numParticles; ++i) {
mm_float4 vel = velm[i];
velocities[order[i]] = Vec3(vel.x, vel.y, vel.z);
}
}
}
void CommonUpdateStateDataKernel::setVelocities(ContextImpl& context, const vector<Vec3>& velocities) {
ContextSelector selector(cc);
const vector<int>& order = cc.getAtomIndex();
int numParticles = context.getSystem().getNumParticles();
if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
mm_double4* velm = (mm_double4*) cc.getPinnedBuffer();
cc.getVelm().download(velm);
for (int i = 0; i < numParticles; ++i) {
mm_double4& vel = velm[i];
const Vec3& p = velocities[order[i]];
vel.x = p[0];
vel.y = p[1];
vel.z = p[2];
}
for (int i = numParticles; i < cc.getPaddedNumAtoms(); i++)
velm[i] = mm_double4(0.0, 0.0, 0.0, 0.0);
cc.getVelm().upload(velm);
}
else {
mm_float4* velm = (mm_float4*) cc.getPinnedBuffer();
cc.getVelm().download(velm);
for (int i = 0; i < numParticles; ++i) {
mm_float4& vel = velm[i];
const Vec3& p = velocities[order[i]];
vel.x = p[0];
vel.y = p[1];
vel.z = p[2];
}
for (int i = numParticles; i < cc.getPaddedNumAtoms(); i++)
velm[i] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f);
cc.getVelm().upload(velm);
}
}
void CommonUpdateStateDataKernel::computeShiftedVelocities(ContextImpl& context, double timeShift, vector<Vec3>& velocities) {
cc.getIntegrationUtilities().computeShiftedVelocities(timeShift, velocities);
}
void CommonUpdateStateDataKernel::getForces(ContextImpl& context, vector<Vec3>& forces) {
ContextSelector selector(cc);
long long* force = (long long*) cc.getPinnedBuffer();
cc.getLongForceBuffer().download(force);
const vector<int>& order = cc.getAtomIndex();
int numParticles = context.getSystem().getNumParticles();
int paddedNumParticles = cc.getPaddedNumAtoms();
forces.resize(numParticles);
double scale = 1.0/(double) 0x100000000LL;
for (int i = 0; i < numParticles; ++i)
forces[order[i]] = Vec3(scale*force[i], scale*force[i+paddedNumParticles], scale*force[i+paddedNumParticles*2]);
}
void CommonUpdateStateDataKernel::getEnergyParameterDerivatives(ContextImpl& context, map<string, double>& derivs) {
ContextSelector selector(cc);
const vector<string>& paramDerivNames = cc.getEnergyParamDerivNames();
int numDerivs = paramDerivNames.size();
if (numDerivs == 0)
return;
derivs = cc.getEnergyParamDerivWorkspace();
ArrayInterface& derivArray = cc.getEnergyParamDerivBuffer();
if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
vector<double> derivBuffers;
derivArray.download(derivBuffers);
for (int i = numDerivs; i < derivArray.getSize(); i += numDerivs)
for (int j = 0; j < numDerivs; j++)
derivBuffers[j] += derivBuffers[i+j];
for (int i = 0; i < numDerivs; i++)
derivs[paramDerivNames[i]] += derivBuffers[i];
}
else {
vector<float> derivBuffers;
derivArray.download(derivBuffers);
for (int i = numDerivs; i < derivArray.getSize(); i += numDerivs)
for (int j = 0; j < numDerivs; j++)
derivBuffers[j] += derivBuffers[i+j];
for (int i = 0; i < numDerivs; i++)
derivs[paramDerivNames[i]] += derivBuffers[i];
}
}
void CommonUpdateStateDataKernel::getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const {
cc.getPeriodicBoxVectors(a, b, c);
}
void CommonUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) {
// If any particles have been wrapped to the first periodic box, we need to unwrap them
// to avoid changing their positions.
vector<Vec3> positions;
for (auto offset : cc.getPosCellOffsets()) {
if (offset.x != 0 || offset.y != 0 || offset.z != 0) {
getPositions(context, positions);
break;
}
}
// Update the vectors.
for (auto ctx : cc.getAllContexts())
ctx->setPeriodicBoxVectors(a, b, c);
if (positions.size() > 0)
setPositions(context, positions);
}
void CommonUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream& stream) {
ContextSelector selector(cc);
int version = 3;
stream.write((char*) &version, sizeof(int));
int precision = (cc.getUseDoublePrecision() ? 2 : cc.getUseMixedPrecision() ? 1 : 0);
stream.write((char*) &precision, sizeof(int));
double time = cc.getTime();
stream.write((char*) &time, sizeof(double));
long long stepCount = cc.getStepCount();
stream.write((char*) &stepCount, sizeof(long long));
int stepsSinceReorder = cc.getStepsSinceReorder();
stream.write((char*) &stepsSinceReorder, sizeof(int));
char* buffer = (char*) cc.getPinnedBuffer();
cc.getPosq().download(buffer);
stream.write(buffer, cc.getPosq().getSize()*cc.getPosq().getElementSize());
if (cc.getUseMixedPrecision()) {
cc.getPosqCorrection().download(buffer);
stream.write(buffer, cc.getPosqCorrection().getSize()*cc.getPosqCorrection().getElementSize());
}
cc.getVelm().download(buffer);
stream.write(buffer, cc.getVelm().getSize()*cc.getVelm().getElementSize());
stream.write((char*) &cc.getAtomIndex()[0], sizeof(int)*cc.getAtomIndex().size());
stream.write((char*) &cc.getPosCellOffsets()[0], sizeof(mm_int4)*cc.getPosCellOffsets().size());
Vec3 boxVectors[3];
cc.getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
stream.write((char*) boxVectors, 3*sizeof(Vec3));
cc.getIntegrationUtilities().createCheckpoint(stream);
SimTKOpenMMUtilities::createCheckpoint(stream);
}
void CommonUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream& stream) {
ContextSelector selector(cc);
int version;
stream.read((char*) &version, sizeof(int));
if (version != 3)
throw OpenMMException("Checkpoint was created with a different version of OpenMM");
int precision;
stream.read((char*) &precision, sizeof(int));
int expectedPrecision = (cc.getUseDoublePrecision() ? 2 : cc.getUseMixedPrecision() ? 1 : 0);
if (precision != expectedPrecision)
throw OpenMMException("Checkpoint was created with a different numeric precision");
double time;
stream.read((char*) &time, sizeof(double));
long long stepCount;
stream.read((char*) &stepCount, sizeof(long long));
int stepsSinceReorder;
stream.read((char*) &stepsSinceReorder, sizeof(int));
vector<ComputeContext*> contexts = cc.getAllContexts();
for (auto ctx : contexts) {
ctx->setTime(time);
ctx->setStepCount(stepCount);
ctx->setStepsSinceReorder(stepsSinceReorder);
}
char* buffer = (char*) cc.getPinnedBuffer();
stream.read(buffer, cc.getPosq().getSize()*cc.getPosq().getElementSize());
cc.getPosq().upload(buffer);
if (cc.getUseMixedPrecision()) {
stream.read(buffer, cc.getPosqCorrection().getSize()*cc.getPosqCorrection().getElementSize());
cc.getPosqCorrection().upload(buffer);
}
stream.read(buffer, cc.getVelm().getSize()*cc.getVelm().getElementSize());
cc.getVelm().upload(buffer);
stream.read((char*) &cc.getAtomIndex()[0], sizeof(int)*cc.getAtomIndex().size());
cc.getAtomIndexArray().upload(cc.getAtomIndex());
stream.read((char*) &cc.getPosCellOffsets()[0], sizeof(mm_int4)*cc.getPosCellOffsets().size());
Vec3 boxVectors[3];
stream.read((char*) &boxVectors, 3*sizeof(Vec3));
for (auto ctx : contexts)
ctx->setPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
cc.getIntegrationUtilities().loadCheckpoint(stream);
SimTKOpenMMUtilities::loadCheckpoint(stream);
for (auto listener : cc.getReorderListeners())
listener->execute();
cc.validateAtomOrder();
}
void CommonApplyConstraintsKernel::initialize(const System& system) { void CommonApplyConstraintsKernel::initialize(const System& system) {
} }
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2011-2018 Stanford University and the Authors. * * Portions copyright (c) 2023 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -27,134 +27,20 @@ ...@@ -27,134 +27,20 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>. * * along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "CudaArray.h"
#include "openmm/System.h"
#include "openmm/common/BondedUtilities.h" #include "openmm/common/BondedUtilities.h"
#include <string> #include "openmm/common/windowsExportCommon.h"
#include <vector>
namespace OpenMM { namespace OpenMM {
class CudaContext;
/** /**
* This class provides a generic mechanism for evaluating bonded interactions. You write only * This class exists only for backward compatibility. It adds no features beyond
* the source code needed to compute one interaction, and this class takes care of creating * the base BondedUtilities class.
* and executing a complete kernel that loops over bonds, evaluates each one, and accumulates
* the resulting forces and energies. This offers two advantages. First, it simplifies the
* task of writing a new Force. Second, it allows multiple forces to be evaluated by a single
* kernel, which reduces overhead and improves performance.
*
* A "bonded interaction" means an interaction that affects a small, fixed set of particles.
* The interaction energy may depend on the positions of only those particles, and the list of
* particles forming a "bond" may not change with time. Examples of bonded interactions
* include HarmonicBondForce, HarmonicAngleForce, and PeriodicTorsionForce.
*
* To create a bonded interaction, call addInteraction(). You pass to it a block of source
* code for evaluating the interaction. The inputs and outputs for that source code are as
* follows:
*
* <ol>
* <li>The index of the bond being evaluated will have been stored in the unsigned int variable "index".</li>
* <li>The indices of the atoms forming that bond will have been stored in the unsigned int variables "atom1",
* "atom2", ....</li>
* <li>The positions of those atoms will have been stored in the real4 variables "pos1", "pos2", ....</li>
* <li>A real variable called "energy" will exist. Your code should add the potential energy of the
* bond to that variable.</li>
* <li>Your code should define real3 variables called "force1", "force2", ... that contain the force to
* apply to each atom.</li>
* </ol>
*
* As a simple example, the following source code would be used to implement a pairwise interaction of
* the form E=r^2:
*
* \verbatim embed:rst:leading-asterisk
* .. code-block:: cpp
*
* real4 delta = pos2-pos1;
* energy += delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
* real3 force1 = 2.0f*delta;
* real3 force2 = -2.0f*delta;
*
* \endverbatim
*
* Interactions will often depend on parameters or other data. Call addArgument() to provide the data
* to this class. It will be passed to the interaction kernel as an argument, and you can refer to it
* from your interaction code.
*/ */
class OPENMM_EXPORT_COMMON CudaBondedUtilities : public BondedUtilities { class OPENMM_EXPORT_COMMON CudaBondedUtilities : public BondedUtilities {
public: public:
CudaBondedUtilities(CudaContext& context); CudaBondedUtilities(ComputeContext& context) : BondedUtilities(context) {
/** }
* Add a bonded interaction.
*
* @param atoms this should have one entry for each bond, and that entry should contain the list
* of atoms involved in the bond. Every entry must have the same number of atoms.
* @param source the code to evaluate the interaction
* @param group the force group in which the interaction should be calculated
*/
void addInteraction(const std::vector<std::vector<int> >& atoms, const std::string& source, int group);
/**
* Add an argument that should be passed to the interaction kernel.
*
* @param data the device memory containing the data to pass
* @param type the data type contained in the memory (e.g. "float4")
* @return the name that will be used for the argument. Any code you pass to addInteraction() should
* refer to it by this name.
*/
std::string addArgument(CUdeviceptr data, const std::string& type);
/**
* Add an argument that should be passed to the interaction kernel.
*
* @param data the array containing the data to pass
* @param type the data type contained in the memory (e.g. "float4")
* @return the name that will be used for the argument. Any code you pass to addInteraction() should
* refer to it by this name.
*/
std::string addArgument(ArrayInterface& data, const std::string& type);
/**
* Register that the interaction kernel will be computing the derivative of the potential energy
* with respect to a parameter.
*
* @param param the name of the parameter
* @return the variable that will be used to accumulate the derivative. Any code you pass to addInteraction() should
* add its contributions to this variable.
*/
std::string addEnergyParameterDerivative(const std::string& param);
/**
* Add some Cuda code that should be included in the program, before the start of the kernel.
* This can be used, for example, to define functions that will be called by the kernel.
*
* @param source the code to include
*/
void addPrefixCode(const std::string& source);
/**
* Initialize this object in preparation for a simulation.
*/
void initialize(const System& system);
/**
* Compute the bonded interactions.
*
* @param groups a set of bit flags for which force groups to include
*/
void computeInteractions(int groups);
private:
std::string createForceSource(int forceIndex, int numBonds, int numAtoms, int group, const std::string& computeForce);
CudaContext& context;
CUfunction kernel;
std::vector<std::vector<std::vector<int> > > forceAtoms;
std::vector<std::vector<int> > indexWidth;
std::vector<std::string> forceSource;
std::vector<int> forceGroup;
std::vector<CUdeviceptr> arguments;
std::vector<std::string> argTypes;
std::vector<std::vector<CudaArray> > atomIndices;
std::vector<std::string> prefixCode;
std::vector<std::string> energyParameterDerivatives;
std::vector<void*> kernelArgs;
int numForceBuffers, maxBonds, allGroups;
bool hasInitializedKernels, hasInteractions;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -147,6 +147,12 @@ public: ...@@ -147,6 +147,12 @@ public:
int getContextIndex() const { int getContextIndex() const {
return contextIndex; return contextIndex;
} }
/**
* Get a list of all contexts being used for the current simulation.
* This is relevant when a simulation is parallelized across multiple devices. In that case,
* one ComputeContext is created for each device.
*/
std::vector<ComputeContext*> getAllContexts();
/** /**
* Get the stream currently being used for execution. * Get the stream currently being used for execution.
*/ */
......
...@@ -39,31 +39,6 @@ ...@@ -39,31 +39,6 @@
namespace OpenMM { namespace OpenMM {
/**
* This abstract class defines an interface for code that can compile CUDA kernels. This allows a plugin to take advantage of runtime compilation
* when running on recent versions of CUDA.
*/
class CudaCompilerKernel : public KernelImpl {
public:
static std::string Name() {
return "CudaCompilerKernel";
}
CudaCompilerKernel(std::string name, const Platform& platform) : KernelImpl(name, platform) {
}
/**
* Compile a kernel to PTX.
*
* @param source the source code for the kernel
* @param options the flags to be passed to the compiler
* @param cu the CudaContext for which the kernel is being compiled
*/
virtual std::string createModule(const std::string& source, const std::string& flags, CudaContext& cu) = 0;
/**
* Get the maximum architecture version the compiler supports.
*/
virtual int getMaxSupportedArchitecture() const = 0;
};
/** /**
* This kernel is invoked at the beginning and end of force and energy computations. It gives the * This kernel is invoked at the beginning and end of force and energy computations. It gives the
* Platform a chance to clear buffers and do other initialization at the beginning, and to do any * Platform a chance to clear buffers and do other initialization at the beginning, and to do any
...@@ -108,121 +83,6 @@ private: ...@@ -108,121 +83,6 @@ private:
CudaContext& cu; CudaContext& cu;
}; };
/**
* This kernel provides methods for setting and retrieving various state data: time, positions,
* velocities, and forces.
*/
class CudaUpdateStateDataKernel : public UpdateStateDataKernel {
public:
CudaUpdateStateDataKernel(std::string name, const Platform& platform, CudaContext& cu) : UpdateStateDataKernel(name, platform), cu(cu) {
}
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
*/
void initialize(const System& system);
/**
* Get the current time (in picoseconds).
*
* @param context the context in which to execute this kernel
*/
double getTime(const ContextImpl& context) const;
/**
* Set the current time (in picoseconds).
*
* @param context the context in which to execute this kernel
*/
void setTime(ContextImpl& context, double time);
/**
* Get the current step count
*
* @param context the context in which to execute this kernel
*/
long long getStepCount(const ContextImpl& context) const;
/**
* Set the current step count
*
* @param context the context in which to execute this kernel
*/
void setStepCount(const ContextImpl& context, long long count);
/**
* Get the positions of all particles.
*
* @param positions on exit, this contains the particle positions
*/
void getPositions(ContextImpl& context, std::vector<Vec3>& positions);
/**
* Set the positions of all particles.
*
* @param positions a vector containg the particle positions
*/
void setPositions(ContextImpl& context, const std::vector<Vec3>& positions);
/**
* Get the velocities of all particles.
*
* @param velocities on exit, this contains the particle velocities
*/
void getVelocities(ContextImpl& context, std::vector<Vec3>& velocities);
/**
* Set the velocities of all particles.
*
* @param velocities a vector containg the particle velocities
*/
void setVelocities(ContextImpl& context, const std::vector<Vec3>& velocities);
/**
* Compute velocities, shifted in time to account for a leapfrog integrator. The shift
* is based on the most recently computed forces.
*
* @param context the context in which to execute this kernel
* @param timeShift the amount by which to shift the velocities in time
* @param velocities the shifted velocities are returned in this
*/
void computeShiftedVelocities(ContextImpl& context, double timeShift, std::vector<Vec3>& velocities);
/**
* Get the current forces on all particles.
*
* @param forces on exit, this contains the forces
*/
void getForces(ContextImpl& context, std::vector<Vec3>& forces);
/**
* Get the current derivatives of the energy with respect to context parameters.
*
* @param derivs on exit, this contains the derivatives
*/
void getEnergyParameterDerivatives(ContextImpl& context, std::map<std::string, double>& derivs);
/**
* Get the current periodic box vectors.
*
* @param a on exit, this contains the vector defining the first edge of the periodic box
* @param b on exit, this contains the vector defining the second edge of the periodic box
* @param c on exit, this contains the vector defining the third edge of the periodic box
*/
void getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const;
/**
* Set the current periodic box vectors.
*
* @param a the vector defining the first edge of the periodic box
* @param b the vector defining the second edge of the periodic box
* @param c the vector defining the third edge of the periodic box
*/
void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c);
/**
* Create a checkpoint recording the current state of the Context.
*
* @param stream an output stream the checkpoint data should be written to
*/
void createCheckpoint(ContextImpl& context, std::ostream& stream);
/**
* Load a checkpoint that was written by createCheckpoint().
*
* @param stream an input stream the checkpoint data should be read from
*/
void loadCheckpoint(ContextImpl& context, std::istream& stream);
private:
CudaContext& cu;
};
/** /**
* This kernel is invoked by NonbondedForce to calculate the forces acting on the system. * This kernel is invoked by NonbondedForce to calculate the forces acting on the system.
*/ */
......
...@@ -634,6 +634,13 @@ CUfunction CudaContext::getKernel(CUmodule& module, const string& name) { ...@@ -634,6 +634,13 @@ CUfunction CudaContext::getKernel(CUmodule& module, const string& name) {
return function; return function;
} }
vector<ComputeContext*> CudaContext::getAllContexts() {
vector<ComputeContext*> result;
for (CudaContext* c : platformData.contexts)
result.push_back(c);
return result;
}
CUstream CudaContext::getCurrentStream() { CUstream CudaContext::getCurrentStream() {
return currentStream; return currentStream;
} }
......
...@@ -72,7 +72,7 @@ KernelImpl* CudaKernelFactory::createKernelImpl(std::string name, const Platform ...@@ -72,7 +72,7 @@ KernelImpl* CudaKernelFactory::createKernelImpl(std::string name, const Platform
if (name == CalcForcesAndEnergyKernel::Name()) if (name == CalcForcesAndEnergyKernel::Name())
return new CudaCalcForcesAndEnergyKernel(name, platform, cu); return new CudaCalcForcesAndEnergyKernel(name, platform, cu);
if (name == UpdateStateDataKernel::Name()) if (name == UpdateStateDataKernel::Name())
return new CudaUpdateStateDataKernel(name, platform, cu); return new CommonUpdateStateDataKernel(name, platform, cu);
if (name == ApplyConstraintsKernel::Name()) if (name == ApplyConstraintsKernel::Name())
return new CommonApplyConstraintsKernel(name, platform, cu); return new CommonApplyConstraintsKernel(name, platform, cu);
if (name == VirtualSitesKernel::Name()) if (name == VirtualSitesKernel::Name())
......
...@@ -86,347 +86,6 @@ double CudaCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, bo ...@@ -86,347 +86,6 @@ double CudaCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, bo
return sum; return sum;
} }
void CudaUpdateStateDataKernel::initialize(const System& system) {
}
double CudaUpdateStateDataKernel::getTime(const ContextImpl& context) const {
return cu.getTime();
}
void CudaUpdateStateDataKernel::setTime(ContextImpl& context, double time) {
vector<CudaContext*>& contexts = cu.getPlatformData().contexts;
for (auto ctx : contexts)
ctx->setTime(time);
}
long long CudaUpdateStateDataKernel::getStepCount(const ContextImpl& context) const {
return cu.getStepCount();
}
void CudaUpdateStateDataKernel::setStepCount(const ContextImpl& context, long long count) {
vector<CudaContext*>& contexts = cu.getPlatformData().contexts;
for (auto ctx : contexts)
ctx->setStepCount(count);
}
void CudaUpdateStateDataKernel::getPositions(ContextImpl& context, vector<Vec3>& positions) {
ContextSelector selector(cu);
int numParticles = context.getSystem().getNumParticles();
positions.resize(numParticles);
vector<float4> posCorrection;
if (cu.getUseDoublePrecision()) {
double4* posq = (double4*) cu.getPinnedBuffer();
cu.getPosq().download(posq);
}
else if (cu.getUseMixedPrecision()) {
float4* posq = (float4*) cu.getPinnedBuffer();
cu.getPosq().download(posq, false);
posCorrection.resize(numParticles);
cu.getPosqCorrection().download(posCorrection);
}
else {
float4* posq = (float4*) cu.getPinnedBuffer();
cu.getPosq().download(posq);
}
// Filling in the output array is done in parallel for speed.
cu.getPlatformData().threads.execute([&] (ThreadPool& threads, int threadIndex) {
// Compute the position of each particle to return to the user. This is done in parallel for speed.
const vector<int>& order = cu.getAtomIndex();
int numParticles = cu.getNumAtoms();
Vec3 boxVectors[3];
cu.getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
int numThreads = threads.getNumThreads();
int start = threadIndex*numParticles/numThreads;
int end = (threadIndex+1)*numParticles/numThreads;
if (cu.getUseDoublePrecision()) {
double4* posq = (double4*) cu.getPinnedBuffer();
for (int i = start; i < end; ++i) {
double4 pos = posq[i];
mm_int4 offset = cu.getPosCellOffsets()[i];
positions[order[i]] = Vec3(pos.x, pos.y, pos.z)-boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z;
}
}
else if (cu.getUseMixedPrecision()) {
float4* posq = (float4*) cu.getPinnedBuffer();
for (int i = start; i < end; ++i) {
float4 pos1 = posq[i];
float4 pos2 = posCorrection[i];
mm_int4 offset = cu.getPosCellOffsets()[i];
positions[order[i]] = Vec3((double)pos1.x+(double)pos2.x, (double)pos1.y+(double)pos2.y, (double)pos1.z+(double)pos2.z)-boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z;
}
}
else {
float4* posq = (float4*) cu.getPinnedBuffer();
for (int i = start; i < end; ++i) {
float4 pos = posq[i];
mm_int4 offset = cu.getPosCellOffsets()[i];
positions[order[i]] = Vec3(pos.x, pos.y, pos.z)-boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z;
}
}
});
cu.getPlatformData().threads.waitForThreads();
}
void CudaUpdateStateDataKernel::setPositions(ContextImpl& context, const vector<Vec3>& positions) {
ContextSelector selector(cu);
const vector<int>& order = cu.getAtomIndex();
int numParticles = context.getSystem().getNumParticles();
if (cu.getUseDoublePrecision()) {
double4* posq = (double4*) cu.getPinnedBuffer();
cu.getPosq().download(posq);
for (int i = 0; i < numParticles; ++i) {
double4& pos = posq[i];
const Vec3& p = positions[order[i]];
pos.x = p[0];
pos.y = p[1];
pos.z = p[2];
}
for (int i = numParticles; i < cu.getPaddedNumAtoms(); i++)
posq[i] = make_double4(0.0, 0.0, 0.0, 0.0);
cu.getPosq().upload(posq);
}
else {
float4* posq = (float4*) cu.getPinnedBuffer();
cu.getPosq().download(posq);
for (int i = 0; i < numParticles; ++i) {
float4& pos = posq[i];
const Vec3& p = positions[order[i]];
pos.x = (float) p[0];
pos.y = (float) p[1];
pos.z = (float) p[2];
}
for (int i = numParticles; i < cu.getPaddedNumAtoms(); i++)
posq[i] = make_float4(0.0f, 0.0f, 0.0f, 0.0f);
cu.getPosq().upload(posq);
}
if (cu.getUseMixedPrecision()) {
float4* posCorrection = (float4*) cu.getPinnedBuffer();
for (int i = 0; i < numParticles; ++i) {
float4& c = posCorrection[i];
const Vec3& p = positions[order[i]];
c.x = (float) (p[0]-(float)p[0]);
c.y = (float) (p[1]-(float)p[1]);
c.z = (float) (p[2]-(float)p[2]);
c.w = 0;
}
for (int i = numParticles; i < cu.getPaddedNumAtoms(); i++)
posCorrection[i] = make_float4(0.0f, 0.0f, 0.0f, 0.0f);
cu.getPosqCorrection().upload(posCorrection);
}
for (auto& offset : cu.getPosCellOffsets())
offset = mm_int4(0, 0, 0, 0);
cu.reorderAtoms();
}
void CudaUpdateStateDataKernel::getVelocities(ContextImpl& context, vector<Vec3>& velocities) {
ContextSelector selector(cu);
const vector<int>& order = cu.getAtomIndex();
int numParticles = context.getSystem().getNumParticles();
velocities.resize(numParticles);
if (cu.getUseDoublePrecision() || cu.getUseMixedPrecision()) {
double4* velm = (double4*) cu.getPinnedBuffer();
cu.getVelm().download(velm);
for (int i = 0; i < numParticles; ++i) {
double4 vel = velm[i];
mm_int4 offset = cu.getPosCellOffsets()[i];
velocities[order[i]] = Vec3(vel.x, vel.y, vel.z);
}
}
else {
float4* velm = (float4*) cu.getPinnedBuffer();
cu.getVelm().download(velm);
for (int i = 0; i < numParticles; ++i) {
float4 vel = velm[i];
mm_int4 offset = cu.getPosCellOffsets()[i];
velocities[order[i]] = Vec3(vel.x, vel.y, vel.z);
}
}
}
void CudaUpdateStateDataKernel::setVelocities(ContextImpl& context, const vector<Vec3>& velocities) {
ContextSelector selector(cu);
const vector<int>& order = cu.getAtomIndex();
int numParticles = context.getSystem().getNumParticles();
if (cu.getUseDoublePrecision() || cu.getUseMixedPrecision()) {
double4* velm = (double4*) cu.getPinnedBuffer();
cu.getVelm().download(velm);
for (int i = 0; i < numParticles; ++i) {
double4& vel = velm[i];
const Vec3& p = velocities[order[i]];
vel.x = p[0];
vel.y = p[1];
vel.z = p[2];
}
for (int i = numParticles; i < cu.getPaddedNumAtoms(); i++)
velm[i] = make_double4(0.0, 0.0, 0.0, 0.0);
cu.getVelm().upload(velm);
}
else {
float4* velm = (float4*) cu.getPinnedBuffer();
cu.getVelm().download(velm);
for (int i = 0; i < numParticles; ++i) {
float4& vel = velm[i];
const Vec3& p = velocities[order[i]];
vel.x = p[0];
vel.y = p[1];
vel.z = p[2];
}
for (int i = numParticles; i < cu.getPaddedNumAtoms(); i++)
velm[i] = make_float4(0.0f, 0.0f, 0.0f, 0.0f);
cu.getVelm().upload(velm);
}
}
void CudaUpdateStateDataKernel::computeShiftedVelocities(ContextImpl& context, double timeShift, vector<Vec3>& velocities) {
cu.getIntegrationUtilities().computeShiftedVelocities(timeShift, velocities);
}
void CudaUpdateStateDataKernel::getForces(ContextImpl& context, vector<Vec3>& forces) {
ContextSelector selector(cu);
long long* force = (long long*) cu.getPinnedBuffer();
cu.getForce().download(force);
const vector<int>& order = cu.getAtomIndex();
int numParticles = context.getSystem().getNumParticles();
int paddedNumParticles = cu.getPaddedNumAtoms();
forces.resize(numParticles);
double scale = 1.0/(double) 0x100000000LL;
for (int i = 0; i < numParticles; ++i)
forces[order[i]] = Vec3(scale*force[i], scale*force[i+paddedNumParticles], scale*force[i+paddedNumParticles*2]);
}
void CudaUpdateStateDataKernel::getEnergyParameterDerivatives(ContextImpl& context, map<string, double>& derivs) {
ContextSelector selector(cu);
const vector<string>& paramDerivNames = cu.getEnergyParamDerivNames();
int numDerivs = paramDerivNames.size();
if (numDerivs == 0)
return;
derivs = cu.getEnergyParamDerivWorkspace();
CudaArray& derivArray = cu.getEnergyParamDerivBuffer();
if (cu.getUseDoublePrecision() || cu.getUseMixedPrecision()) {
vector<double> derivBuffers;
derivArray.download(derivBuffers);
for (int i = numDerivs; i < derivArray.getSize(); i += numDerivs)
for (int j = 0; j < numDerivs; j++)
derivBuffers[j] += derivBuffers[i+j];
for (int i = 0; i < numDerivs; i++)
derivs[paramDerivNames[i]] += derivBuffers[i];
}
else {
vector<float> derivBuffers;
derivArray.download(derivBuffers);
for (int i = numDerivs; i < derivArray.getSize(); i += numDerivs)
for (int j = 0; j < numDerivs; j++)
derivBuffers[j] += derivBuffers[i+j];
for (int i = 0; i < numDerivs; i++)
derivs[paramDerivNames[i]] += derivBuffers[i];
}
}
void CudaUpdateStateDataKernel::getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const {
cu.getPeriodicBoxVectors(a, b, c);
}
void CudaUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) {
vector<CudaContext*>& contexts = cu.getPlatformData().contexts;
// If any particles have been wrapped to the first periodic box, we need to unwrap them
// to avoid changing their positions.
vector<Vec3> positions;
for (auto& offset : cu.getPosCellOffsets()) {
if (offset.x != 0 || offset.y != 0 || offset.z != 0) {
getPositions(context, positions);
break;
}
}
// Update the vectors.
for (auto ctx : contexts)
ctx->setPeriodicBoxVectors(a, b, c);
if (positions.size() > 0)
setPositions(context, positions);
}
void CudaUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream& stream) {
ContextSelector selector(cu);
int version = 3;
stream.write((char*) &version, sizeof(int));
int precision = (cu.getUseDoublePrecision() ? 2 : cu.getUseMixedPrecision() ? 1 : 0);
stream.write((char*) &precision, sizeof(int));
double time = cu.getTime();
stream.write((char*) &time, sizeof(double));
long long stepCount = cu.getStepCount();
stream.write((char*) &stepCount, sizeof(long long));
int stepsSinceReorder = cu.getStepsSinceReorder();
stream.write((char*) &stepsSinceReorder, sizeof(int));
char* buffer = (char*) cu.getPinnedBuffer();
cu.getPosq().download(buffer);
stream.write(buffer, cu.getPosq().getSize()*cu.getPosq().getElementSize());
if (cu.getUseMixedPrecision()) {
cu.getPosqCorrection().download(buffer);
stream.write(buffer, cu.getPosqCorrection().getSize()*cu.getPosqCorrection().getElementSize());
}
cu.getVelm().download(buffer);
stream.write(buffer, cu.getVelm().getSize()*cu.getVelm().getElementSize());
stream.write((char*) &cu.getAtomIndex()[0], sizeof(int)*cu.getAtomIndex().size());
stream.write((char*) &cu.getPosCellOffsets()[0], sizeof(int4)*cu.getPosCellOffsets().size());
Vec3 boxVectors[3];
cu.getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
stream.write((char*) boxVectors, 3*sizeof(Vec3));
cu.getIntegrationUtilities().createCheckpoint(stream);
SimTKOpenMMUtilities::createCheckpoint(stream);
}
void CudaUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream& stream) {
ContextSelector selector(cu);
int version;
stream.read((char*) &version, sizeof(int));
if (version != 3)
throw OpenMMException("Checkpoint was created with a different version of OpenMM");
int precision;
stream.read((char*) &precision, sizeof(int));
int expectedPrecision = (cu.getUseDoublePrecision() ? 2 : cu.getUseMixedPrecision() ? 1 : 0);
if (precision != expectedPrecision)
throw OpenMMException("Checkpoint was created with a different numeric precision");
double time;
stream.read((char*) &time, sizeof(double));
long long stepCount;
stream.read((char*) &stepCount, sizeof(long long));
int stepsSinceReorder;
stream.read((char*) &stepsSinceReorder, sizeof(int));
vector<CudaContext*>& contexts = cu.getPlatformData().contexts;
for (auto ctx : contexts) {
ctx->setTime(time);
ctx->setStepCount(stepCount);
ctx->setStepsSinceReorder(stepsSinceReorder);
}
char* buffer = (char*) cu.getPinnedBuffer();
stream.read(buffer, cu.getPosq().getSize()*cu.getPosq().getElementSize());
cu.getPosq().upload(buffer);
if (cu.getUseMixedPrecision()) {
stream.read(buffer, cu.getPosqCorrection().getSize()*cu.getPosqCorrection().getElementSize());
cu.getPosqCorrection().upload(buffer);
}
stream.read(buffer, cu.getVelm().getSize()*cu.getVelm().getElementSize());
cu.getVelm().upload(buffer);
stream.read((char*) &cu.getAtomIndex()[0], sizeof(int)*cu.getAtomIndex().size());
cu.getAtomIndexArray().upload(cu.getAtomIndex());
stream.read((char*) &cu.getPosCellOffsets()[0], sizeof(int4)*cu.getPosCellOffsets().size());
Vec3 boxVectors[3];
stream.read((char*) &boxVectors, 3*sizeof(Vec3));
for (auto ctx : contexts)
ctx->setPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
cu.getIntegrationUtilities().loadCheckpoint(stream);
SimTKOpenMMUtilities::loadCheckpoint(stream);
for (auto listener : cu.getReorderListeners())
listener->execute();
cu.validateAtomOrder();
}
class CudaCalcNonbondedForceKernel::ForceInfo : public CudaForceInfo { class CudaCalcNonbondedForceKernel::ForceInfo : public CudaForceInfo {
public: public:
ForceInfo(const NonbondedForce& force) : force(force) { ForceInfo(const NonbondedForce& force) : force(force) {
...@@ -997,7 +656,7 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon ...@@ -997,7 +656,7 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon
} }
exclusionAtoms.upload(exclusionAtomsVec); exclusionAtoms.upload(exclusionAtomsVec);
map<string, string> replacements; map<string, string> replacements;
replacements["PARAMS"] = cu.getBondedUtilities().addArgument(exclusionParams.getDevicePointer(), "float4"); replacements["PARAMS"] = cu.getBondedUtilities().addArgument(exclusionParams, "float4");
replacements["EWALD_ALPHA"] = cu.doubleToString(alpha); replacements["EWALD_ALPHA"] = cu.doubleToString(alpha);
replacements["TWO_OVER_SQRT_PI"] = cu.doubleToString(2.0/sqrt(M_PI)); replacements["TWO_OVER_SQRT_PI"] = cu.doubleToString(2.0/sqrt(M_PI));
replacements["DO_LJPME"] = doLJPME ? "1" : "0"; replacements["DO_LJPME"] = doLJPME ? "1" : "0";
...@@ -1059,7 +718,7 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon ...@@ -1059,7 +718,7 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon
baseExceptionParams.upload(baseExceptionParamsVec); baseExceptionParams.upload(baseExceptionParamsVec);
map<string, string> replacements; map<string, string> replacements;
replacements["APPLY_PERIODIC"] = (usePeriodic && force.getExceptionsUsePeriodicBoundaryConditions() ? "1" : "0"); replacements["APPLY_PERIODIC"] = (usePeriodic && force.getExceptionsUsePeriodicBoundaryConditions() ? "1" : "0");
replacements["PARAMS"] = cu.getBondedUtilities().addArgument(exceptionParams.getDevicePointer(), "float4"); replacements["PARAMS"] = cu.getBondedUtilities().addArgument(exceptionParams, "float4");
if (force.getIncludeDirectSpace()) if (force.getIncludeDirectSpace())
cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CommonKernelSources::nonbondedExceptions, replacements), force.getForceGroup()); cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CommonKernelSources::nonbondedExceptions, replacements), force.getForceGroup());
} }
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2011-2022 Stanford University and the Authors. * * Portions copyright (c) 2023 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -27,133 +27,19 @@ ...@@ -27,133 +27,19 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>. * * along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "OpenCLArray.h"
#include "openmm/System.h"
#include "openmm/common/BondedUtilities.h" #include "openmm/common/BondedUtilities.h"
#include <string>
#include <vector>
namespace OpenMM { namespace OpenMM {
class OpenCLContext;
/** /**
* This class provides a generic mechanism for evaluating bonded interactions. You write only * This class exists only for backward compatibility. It adds no features beyond
* the source code needed to compute one interaction, and this class takes care of creating * the base BondedUtilities class.
* and executing a complete kernel that loops over bonds, evaluates each one, and accumulates
* the resulting forces and energies. This offers two advantages. First, it simplifies the
* task of writing a new Force. Second, it allows multiple forces to be evaluated by a single
* kernel, which reduces overhead and improves performance.
*
* A "bonded interaction" means an interaction that affects a small, fixed set of particles.
* The interaction energy may depend on the positions of only those particles, and the list of
* particles forming a "bond" may not change with time. Examples of bonded interactions
* include HarmonicBondForce, HarmonicAngleForce, and PeriodicTorsionForce.
*
* To create a bonded interaction, call addInteraction(). You pass to it a block of source
* code for evaluating the interaction. The inputs and outputs for that source code are as
* follows:
*
* <ol>
* <li>The index of the bond being evaluated will have been stored in the unsigned int variable "index".</li>
* <li>The indices of the atoms forming that bond will have been stored in the unsigned int variables "atom1",
* "atom2", ....</li>
* <li>The positions of those atoms will have been stored in the real4 variables "pos1", "pos2", ....</li>
* <li>A real variable called "energy" will exist. Your code should add the potential energy of the
* bond to that variable.</li>
* <li>Your code should define real4 variables called "force1", "force2", ... that contain the force to
* apply to each atom.</li>
* </ol>
*
* As a simple example, the following source code would be used to implement a pairwise interaction of
* the form E=r^2:
*
* \verbatim embed:rst:leading-asterisk
* .. code-block:: cpp
*
* real4 delta = pos2-pos1;
* energy += delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
* real4 force1 = 2.0f*delta;
* real4 force2 = -2.0f*delta;
*
* \endverbatim
*
* Interactions will often depend on parameters or other data. Call addArgument() to provide the data
* to this class. It will be passed to the interaction kernel as an argument, and you can refer to it
* from your interaction code.
*/ */
class OPENMM_EXPORT_COMMON OpenCLBondedUtilities : public BondedUtilities { class OPENMM_EXPORT_COMMON OpenCLBondedUtilities : public BondedUtilities {
public: public:
OpenCLBondedUtilities(OpenCLContext& context); OpenCLBondedUtilities(ComputeContext& context) : BondedUtilities(context) {
/** }
* Add a bonded interaction.
*
* @param atoms this should have one entry for each bond, and that entry should contain the list
* of atoms involved in the bond. Every entry must have the same number of atoms.
* @param source the code to evaluate the interaction
* @param group the force group in which the interaction should be calculated
*/
void addInteraction(const std::vector<std::vector<int> >& atoms, const std::string& source, int group);
/**
* Add an argument that should be passed to the interaction kernel.
*
* @param data the device memory containing the data to pass
* @param type the data type contained in the memory (e.g. "float4")
* @return the name that will be used for the argument. Any code you pass to addInteraction() should
* refer to it by this name.
*/
std::string addArgument(cl::Memory& data, const std::string& type);
/**
* Add an argument that should be passed to the interaction kernel.
*
* @param data the array containing the data to pass
* @param type the data type contained in the memory (e.g. "float4")
* @return the name that will be used for the argument. Any code you pass to addInteraction() should
* refer to it by this name.
*/
std::string addArgument(ArrayInterface& data, const std::string& type);
/**
* Register that the interaction kernel will be computing the derivative of the potential energy
* with respect to a parameter.
*
* @param param the name of the parameter
* @return the variable that will be used to accumulate the derivative. Any code you pass to addInteraction() should
* add its contributions to this variable.
*/
std::string addEnergyParameterDerivative(const std::string& param);
/**
* Add some OpenCL code that should be included in the program, before the start of the kernel.
* This can be used, for example, to define functions that will be called by the kernel.
*
* @param source the code to include
*/
void addPrefixCode(const std::string& source);
/**
* Initialize this object in preparation for a simulation.
*/
void initialize(const System& system);
/**
* Compute the bonded interactions.
*
* @param groups a set of bit flags for which force groups to include
*/
void computeInteractions(int groups);
private:
std::string createForceSource(int forceIndex, int numBonds, int numAtoms, int group, const std::string& computeForce);
OpenCLContext& context;
cl::Kernel kernel;
std::vector<std::vector<std::vector<int> > > forceAtoms;
std::vector<int> indexWidth;
std::vector<std::string> forceSource;
std::vector<int> forceGroup;
std::vector<cl::Memory*> arguments;
std::vector<std::string> argTypes;
std::vector<OpenCLArray> atomIndices;
std::vector<std::string> prefixCode;
std::vector<std::string> energyParameterDerivatives;
int maxBonds, allGroups;
bool hasInitializedKernels;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. * * Portions copyright (c) 2009-2023 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -191,6 +191,12 @@ public: ...@@ -191,6 +191,12 @@ public:
int getContextIndex() const { int getContextIndex() const {
return contextIndex; return contextIndex;
} }
/**
* Get a list of all contexts being used for the current simulation.
* This is relevant when a simulation is parallelized across multiple devices. In that case,
* one ComputeContext is created for each device.
*/
std::vector<ComputeContext*> getAllContexts();
/** /**
* Get the cl::CommandQueue currently being used for execution. * Get the cl::CommandQueue currently being used for execution.
*/ */
......
...@@ -82,121 +82,6 @@ private: ...@@ -82,121 +82,6 @@ private:
OpenCLContext& cl; OpenCLContext& cl;
}; };
/**
* This kernel provides methods for setting and retrieving various state data: time, positions,
* velocities, and forces.
*/
class OpenCLUpdateStateDataKernel : public UpdateStateDataKernel {
public:
OpenCLUpdateStateDataKernel(std::string name, const Platform& platform, OpenCLContext& cl) : UpdateStateDataKernel(name, platform), cl(cl) {
}
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
*/
void initialize(const System& system);
/**
* Get the current time (in picoseconds).
*
* @param context the context in which to execute this kernel
*/
double getTime(const ContextImpl& context) const;
/**
* Set the current time (in picoseconds).
*
* @param context the context in which to execute this kernel
*/
void setTime(ContextImpl& context, double time);
/**
* Get the current step count
*
* @param context the context in which to execute this kernel
*/
long long getStepCount(const ContextImpl& context) const;
/**
* Set the current step count
*
* @param context the context in which to execute this kernel
*/
void setStepCount(const ContextImpl& context, long long count);
/**
* Get the positions of all particles.
*
* @param positions on exit, this contains the particle positions
*/
void getPositions(ContextImpl& context, std::vector<Vec3>& positions);
/**
* Set the positions of all particles.
*
* @param positions a vector containg the particle positions
*/
void setPositions(ContextImpl& context, const std::vector<Vec3>& positions);
/**
* Get the velocities of all particles.
*
* @param velocities on exit, this contains the particle velocities
*/
void getVelocities(ContextImpl& context, std::vector<Vec3>& velocities);
/**
* Set the velocities of all particles.
*
* @param velocities a vector containg the particle velocities
*/
void setVelocities(ContextImpl& context, const std::vector<Vec3>& velocities);
/**
* Compute velocities, shifted in time to account for a leapfrog integrator. The shift
* is based on the most recently computed forces.
*
* @param context the context in which to execute this kernel
* @param timeShift the amount by which to shift the velocities in time
* @param velocities the shifted velocities are returned in this
*/
void computeShiftedVelocities(ContextImpl& context, double timeShift, std::vector<Vec3>& velocities);
/**
* Get the current forces on all particles.
*
* @param forces on exit, this contains the forces
*/
void getForces(ContextImpl& context, std::vector<Vec3>& forces);
/**
* Get the current derivatives of the energy with respect to context parameters.
*
* @param derivs on exit, this contains the derivatives
*/
void getEnergyParameterDerivatives(ContextImpl& context, std::map<std::string, double>& derivs);
/**
* Get the current periodic box vectors.
*
* @param a on exit, this contains the vector defining the first edge of the periodic box
* @param b on exit, this contains the vector defining the second edge of the periodic box
* @param c on exit, this contains the vector defining the third edge of the periodic box
*/
void getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const;
/**
* Set the current periodic box vectors.
*
* @param a the vector defining the first edge of the periodic box
* @param b the vector defining the second edge of the periodic box
* @param c the vector defining the third edge of the periodic box
*/
void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c);
/**
* Create a checkpoint recording the current state of the Context.
*
* @param stream an output stream the checkpoint data should be written to
*/
void createCheckpoint(ContextImpl& context, std::ostream& stream);
/**
* Load a checkpoint that was written by createCheckpoint().
*
* @param stream an input stream the checkpoint data should be read from
*/
void loadCheckpoint(ContextImpl& context, std::istream& stream);
private:
OpenCLContext& cl;
};
/** /**
* This kernel is invoked by NonbondedForce to calculate the forces acting on the system. * This kernel is invoked by NonbondedForce to calculate the forces acting on the system.
*/ */
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2011-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "OpenCLBondedUtilities.h"
#include "OpenCLContext.h"
#include "OpenCLExpressionUtilities.h"
#include "openmm/OpenMMException.h"
#include "OpenCLNonbondedUtilities.h"
#include <iostream>
using namespace OpenMM;
using namespace std;
OpenCLBondedUtilities::OpenCLBondedUtilities(OpenCLContext& context) : context(context), maxBonds(0), allGroups(0), hasInitializedKernels(false) {
}
void OpenCLBondedUtilities::addInteraction(const vector<vector<int> >& atoms, const string& source, int group) {
if (atoms.size() > 0) {
forceAtoms.push_back(atoms);
forceSource.push_back(source);
forceGroup.push_back(group);
allGroups |= 1<<group;
int width = 1;
while (width < (int) atoms[0].size())
width *= 2;
indexWidth.push_back(width);
}
}
string OpenCLBondedUtilities::addArgument(cl::Memory& data, const string& type) {
arguments.push_back(&data);
argTypes.push_back(type);
return "customArg"+context.intToString(arguments.size());
}
string OpenCLBondedUtilities::addArgument(ArrayInterface& data, const string& type) {
return addArgument(context.unwrap(data).getDeviceBuffer(), type);
}
string OpenCLBondedUtilities::addEnergyParameterDerivative(const string& param) {
// See if the parameter has already been added.
int index;
for (index = 0; index < energyParameterDerivatives.size(); index++)
if (param == energyParameterDerivatives[index])
break;
if (index == energyParameterDerivatives.size())
energyParameterDerivatives.push_back(param);
context.addEnergyParameterDerivative(param);
return string("energyParamDeriv")+context.intToString(index);
}
void OpenCLBondedUtilities::addPrefixCode(const string& source) {
for (int i = 0; i < (int) prefixCode.size(); i++)
if (prefixCode[i] == source)
return;
prefixCode.push_back(source);
}
void OpenCLBondedUtilities::initialize(const System& system) {
int numForces = forceAtoms.size();
if (numForces == 0)
return;
// Build the lists of atom indices.
atomIndices.resize(numForces);
for (int i = 0; i < numForces; i++) {
int numBonds = forceAtoms[i].size();
int numAtoms = forceAtoms[i][0].size();
int width = indexWidth[i];
vector<cl_uint> indexVec(width*numBonds);
for (int bond = 0; bond < numBonds; bond++) {
for (int atom = 0; atom < numAtoms; atom++)
indexVec[bond*width+atom] = forceAtoms[i][bond][atom];
}
atomIndices[i].initialize<cl_uint>(context, indexVec.size(), "bondedIndices");
atomIndices[i].upload(indexVec);
}
// Create the kernel.
stringstream s;
for (int i = 0; i < (int) prefixCode.size(); i++)
s<<prefixCode[i];
s<<"__kernel void computeBondedForces(__global unsigned long* restrict forceBuffers, __global mixed* restrict energyBuffer, __global const real4* restrict posq, int groups, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ";
for (int force = 0; force < numForces; force++) {
string indexType = "uint"+(indexWidth[force] == 1 ? "" : context.intToString(indexWidth[force]));
s<<", __global const "<<indexType<<"* restrict atomIndices"<<force;
}
for (int i = 0; i < (int) arguments.size(); i++)
s<<", __global "<<argTypes[i]<<"* customArg"<<(i+1);
if (energyParameterDerivatives.size() > 0)
s<<", __global mixed* restrict energyParamDerivs";
s<<") {\n";
s<<"mixed energy = 0;\n";
for (int i = 0; i < energyParameterDerivatives.size(); i++)
s<<"mixed energyParamDeriv"<<i<<" = 0;\n";
for (int force = 0; force < numForces; force++)
s<<createForceSource(force, forceAtoms[force].size(), forceAtoms[force][0].size(), forceGroup[force], forceSource[force]);
s<<"energyBuffer[get_global_id(0)] += energy;\n";
const vector<string>& allParamDerivNames = context.getEnergyParamDerivNames();
int numDerivs = allParamDerivNames.size();
for (int i = 0; i < energyParameterDerivatives.size(); i++)
for (int index = 0; index < numDerivs; index++)
if (allParamDerivNames[index] == energyParameterDerivatives[i])
s<<"energyParamDerivs[get_global_id(0)*"<<numDerivs<<"+"<<index<<"] += energyParamDeriv"<<i<<";\n";
s<<"}\n";
map<string, string> defines;
defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
cl::Program program = context.createProgram(s.str(), defines);
kernel = cl::Kernel(program, "computeBondedForces");
forceAtoms.clear();
forceSource.clear();
}
string OpenCLBondedUtilities::createForceSource(int forceIndex, int numBonds, int numAtoms, int group, const string& computeForce) {
maxBonds = max(maxBonds, numBonds);
int width = 1;
while (width < numAtoms)
width *= 2;
string suffix1[] = {""};
string suffix4[] = {".x", ".y", ".z", ".w"};
string suffix16[] = {".s0", ".s1", ".s2", ".s3", ".s4", ".s5", ".s6", ".s7",
".s8", ".s9", ".s10", ".s11", ".s12", ".s13", ".s14", ".s15"};
string* suffix;
if (width == 1)
suffix = suffix1;
else if (width <= 4)
suffix = suffix4;
else
suffix = suffix16;
string indexType = "uint"+(width == 1 ? "" : context.intToString(width));
stringstream s;
s<<"if ((groups&"<<(1<<group)<<") != 0)\n";
s<<"for (unsigned int index = get_global_id(0); index < "<<numBonds<<"; index += get_global_size(0)) {\n";
s<<" "<<indexType<<" atoms = atomIndices"<<forceIndex<<"[index];\n";
for (int i = 0; i < numAtoms; i++) {
s<<" unsigned int atom"<<(i+1)<<" = atoms"<<suffix[i]<<";\n";
s<<" real4 pos"<<(i+1)<<" = posq[atom"<<(i+1)<<"];\n";
}
s<<computeForce<<"\n";
for (int i = 0; i < numAtoms; i++) {
s<<" {\n";
s<<" ATOMIC_ADD(&forceBuffers[atom"<<(i+1)<<"], (mm_ulong) realToFixedPoint(force"<<(i+1)<<".x));\n";
s<<" ATOMIC_ADD(&forceBuffers[atom"<<(i+1)<<"+PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(force"<<(i+1)<<".y));\n";
s<<" ATOMIC_ADD(&forceBuffers[atom"<<(i+1)<<"+2*PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(force"<<(i+1)<<".z));\n";
s<<" }\n";
}
s<<"}\n";
return s.str();
}
void OpenCLBondedUtilities::computeInteractions(int groups) {
if ((groups&allGroups) == 0)
return;
if (!hasInitializedKernels) {
hasInitializedKernels = true;
int index = 0;
kernel.setArg<cl::Buffer>(index++, context.getLongForceBuffer().getDeviceBuffer());
kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer());
kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
index += 6;
for (int j = 0; j < (int) atomIndices.size(); j++)
kernel.setArg<cl::Buffer>(index++, atomIndices[j].getDeviceBuffer());
for (int j = 0; j < (int) arguments.size(); j++)
kernel.setArg<cl::Memory>(index++, *arguments[j]);
if (energyParameterDerivatives.size() > 0)
kernel.setArg<cl::Memory>(index++, context.getEnergyParamDerivBuffer().getDeviceBuffer());
}
kernel.setArg<cl_int>(3, groups);
if (context.getUseDoublePrecision()) {
kernel.setArg<mm_double4>(4, context.getPeriodicBoxSizeDouble());
kernel.setArg<mm_double4>(5, context.getInvPeriodicBoxSizeDouble());
kernel.setArg<mm_double4>(6, context.getPeriodicBoxVecXDouble());
kernel.setArg<mm_double4>(7, context.getPeriodicBoxVecYDouble());
kernel.setArg<mm_double4>(8, context.getPeriodicBoxVecZDouble());
}
else {
kernel.setArg<mm_float4>(4, context.getPeriodicBoxSize());
kernel.setArg<mm_float4>(5, context.getInvPeriodicBoxSize());
kernel.setArg<mm_float4>(6, context.getPeriodicBoxVecX());
kernel.setArg<mm_float4>(7, context.getPeriodicBoxVecY());
kernel.setArg<mm_float4>(8, context.getPeriodicBoxVecZ());
}
context.executeKernel(kernel, maxBonds);
}
...@@ -647,6 +647,13 @@ cl::Program OpenCLContext::createProgram(const string source, const map<string, ...@@ -647,6 +647,13 @@ cl::Program OpenCLContext::createProgram(const string source, const map<string,
return program; return program;
} }
vector<ComputeContext*> OpenCLContext::getAllContexts() {
vector<ComputeContext*> result;
for (OpenCLContext* c : platformData.contexts)
result.push_back(c);
return result;
}
cl::CommandQueue& OpenCLContext::getQueue() { cl::CommandQueue& OpenCLContext::getQueue() {
return currentQueue; return currentQueue;
} }
......
...@@ -70,7 +70,7 @@ KernelImpl* OpenCLKernelFactory::createKernelImpl(std::string name, const Platfo ...@@ -70,7 +70,7 @@ KernelImpl* OpenCLKernelFactory::createKernelImpl(std::string name, const Platfo
if (name == CalcForcesAndEnergyKernel::Name()) if (name == CalcForcesAndEnergyKernel::Name())
return new OpenCLCalcForcesAndEnergyKernel(name, platform, cl); return new OpenCLCalcForcesAndEnergyKernel(name, platform, cl);
if (name == UpdateStateDataKernel::Name()) if (name == UpdateStateDataKernel::Name())
return new OpenCLUpdateStateDataKernel(name, platform, cl); return new CommonUpdateStateDataKernel(name, platform, cl);
if (name == ApplyConstraintsKernel::Name()) if (name == ApplyConstraintsKernel::Name())
return new CommonApplyConstraintsKernel(name, platform, cl); return new CommonApplyConstraintsKernel(name, platform, cl);
if (name == VirtualSitesKernel::Name()) if (name == VirtualSitesKernel::Name())
......
...@@ -101,347 +101,6 @@ double OpenCLCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, ...@@ -101,347 +101,6 @@ double OpenCLCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context,
return sum; return sum;
} }
void OpenCLUpdateStateDataKernel::initialize(const System& system) {
}
double OpenCLUpdateStateDataKernel::getTime(const ContextImpl& context) const {
return cl.getTime();
}
void OpenCLUpdateStateDataKernel::setTime(ContextImpl& context, double time) {
vector<OpenCLContext*>& contexts = cl.getPlatformData().contexts;
for (auto ctx : contexts)
ctx->setTime(time);
}
long long OpenCLUpdateStateDataKernel::getStepCount(const ContextImpl& context) const {
return cl.getStepCount();
}
void OpenCLUpdateStateDataKernel::setStepCount(const ContextImpl& context, long long count) {
vector<OpenCLContext*>& contexts = cl.getPlatformData().contexts;
for (auto ctx : contexts)
ctx->setStepCount(count);
}
void OpenCLUpdateStateDataKernel::getPositions(ContextImpl& context, vector<Vec3>& positions) {
int numParticles = context.getSystem().getNumParticles();
positions.resize(numParticles);
vector<mm_float4> posCorrection;
if (cl.getUseDoublePrecision()) {
mm_double4* posq = (mm_double4*) cl.getPinnedBuffer();
cl.getPosq().download(posq);
}
else if (cl.getUseMixedPrecision()) {
mm_float4* posq = (mm_float4*) cl.getPinnedBuffer();
cl.getPosq().download(posq, false);
posCorrection.resize(numParticles);
cl.getPosqCorrection().download(posCorrection);
}
else {
mm_float4* posq = (mm_float4*) cl.getPinnedBuffer();
cl.getPosq().download(posq);
}
// Filling in the output array is done in parallel for speed.
cl.getPlatformData().threads.execute([&] (ThreadPool& threads, int threadIndex) {
// Compute the position of each particle to return to the user. This is done in parallel for speed.
const vector<int>& order = cl.getAtomIndex();
int numParticles = cl.getNumAtoms();
Vec3 boxVectors[3];
cl.getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
int numThreads = threads.getNumThreads();
int start = threadIndex*numParticles/numThreads;
int end = (threadIndex+1)*numParticles/numThreads;
if (cl.getUseDoublePrecision()) {
mm_double4* posq = (mm_double4*) cl.getPinnedBuffer();
for (int i = start; i < end; ++i) {
mm_double4 pos = posq[i];
mm_int4 offset = cl.getPosCellOffsets()[i];
positions[order[i]] = Vec3(pos.x, pos.y, pos.z)-boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z;
}
}
else if (cl.getUseMixedPrecision()) {
mm_float4* posq = (mm_float4*) cl.getPinnedBuffer();
for (int i = start; i < end; ++i) {
mm_float4 pos1 = posq[i];
mm_float4 pos2 = posCorrection[i];
mm_int4 offset = cl.getPosCellOffsets()[i];
positions[order[i]] = Vec3((double)pos1.x+(double)pos2.x, (double)pos1.y+(double)pos2.y, (double)pos1.z+(double)pos2.z)-boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z;
}
}
else {
mm_float4* posq = (mm_float4*) cl.getPinnedBuffer();
for (int i = start; i < end; ++i) {
mm_float4 pos = posq[i];
mm_int4 offset = cl.getPosCellOffsets()[i];
positions[order[i]] = Vec3(pos.x, pos.y, pos.z)-boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z;
}
}
});
cl.getPlatformData().threads.waitForThreads();
}
void OpenCLUpdateStateDataKernel::setPositions(ContextImpl& context, const vector<Vec3>& positions) {
const vector<cl_int>& order = cl.getAtomIndex();
int numParticles = context.getSystem().getNumParticles();
if (cl.getUseDoublePrecision()) {
mm_double4* posq = (mm_double4*) cl.getPinnedBuffer();
cl.getPosq().download(posq);
for (int i = 0; i < numParticles; ++i) {
mm_double4& pos = posq[i];
const Vec3& p = positions[order[i]];
pos.x = p[0];
pos.y = p[1];
pos.z = p[2];
}
for (int i = numParticles; i < cl.getPaddedNumAtoms(); i++)
posq[i] = mm_double4(0.0, 0.0, 0.0, 0.0);
cl.getPosq().upload(posq);
}
else {
mm_float4* posq = (mm_float4*) cl.getPinnedBuffer();
cl.getPosq().download(posq);
for (int i = 0; i < numParticles; ++i) {
mm_float4& pos = posq[i];
const Vec3& p = positions[order[i]];
pos.x = (cl_float) p[0];
pos.y = (cl_float) p[1];
pos.z = (cl_float) p[2];
}
for (int i = numParticles; i < cl.getPaddedNumAtoms(); i++)
posq[i] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f);
cl.getPosq().upload(posq);
}
if (cl.getUseMixedPrecision()) {
mm_float4* posCorrection = (mm_float4*) cl.getPinnedBuffer();
for (int i = 0; i < numParticles; ++i) {
mm_float4& c = posCorrection[i];
const Vec3& p = positions[order[i]];
c.x = (cl_float) (p[0]-(cl_float)p[0]);
c.y = (cl_float) (p[1]-(cl_float)p[1]);
c.z = (cl_float) (p[2]-(cl_float)p[2]);
c.w = 0;
}
for (int i = numParticles; i < cl.getPaddedNumAtoms(); i++)
posCorrection[i] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f);
cl.getPosqCorrection().upload(posCorrection);
}
for (auto& offset : cl.getPosCellOffsets())
offset = mm_int4(0, 0, 0, 0);
cl.reorderAtoms();
}
void OpenCLUpdateStateDataKernel::getVelocities(ContextImpl& context, vector<Vec3>& velocities) {
const vector<cl_int>& order = cl.getAtomIndex();
int numParticles = context.getSystem().getNumParticles();
velocities.resize(numParticles);
if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) {
mm_double4* velm = (mm_double4*) cl.getPinnedBuffer();
cl.getVelm().download(velm);
for (int i = 0; i < numParticles; ++i) {
mm_double4 vel = velm[i];
velocities[order[i]] = Vec3(vel.x, vel.y, vel.z);
}
}
else {
mm_float4* velm = (mm_float4*) cl.getPinnedBuffer();
cl.getVelm().download(velm);
for (int i = 0; i < numParticles; ++i) {
mm_float4 vel = velm[i];
velocities[order[i]] = Vec3(vel.x, vel.y, vel.z);
}
}
}
void OpenCLUpdateStateDataKernel::setVelocities(ContextImpl& context, const vector<Vec3>& velocities) {
const vector<cl_int>& order = cl.getAtomIndex();
int numParticles = context.getSystem().getNumParticles();
if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) {
mm_double4* velm = (mm_double4*) cl.getPinnedBuffer();
cl.getVelm().download(velm);
for (int i = 0; i < numParticles; ++i) {
mm_double4& vel = velm[i];
const Vec3& p = velocities[order[i]];
vel.x = p[0];
vel.y = p[1];
vel.z = p[2];
}
for (int i = numParticles; i < cl.getPaddedNumAtoms(); i++)
velm[i] = mm_double4(0.0, 0.0, 0.0, 0.0);
cl.getVelm().upload(velm);
}
else {
mm_float4* velm = (mm_float4*) cl.getPinnedBuffer();
cl.getVelm().download(velm);
for (int i = 0; i < numParticles; ++i) {
mm_float4& vel = velm[i];
const Vec3& p = velocities[order[i]];
vel.x = p[0];
vel.y = p[1];
vel.z = p[2];
}
for (int i = numParticles; i < cl.getPaddedNumAtoms(); i++)
velm[i] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f);
cl.getVelm().upload(velm);
}
}
void OpenCLUpdateStateDataKernel::computeShiftedVelocities(ContextImpl& context, double timeShift, vector<Vec3>& velocities) {
cl.getIntegrationUtilities().computeShiftedVelocities(timeShift, velocities);
}
void OpenCLUpdateStateDataKernel::getForces(ContextImpl& context, vector<Vec3>& forces) {
const vector<cl_int>& order = cl.getAtomIndex();
int numParticles = context.getSystem().getNumParticles();
forces.resize(numParticles);
if (cl.getUseDoublePrecision()) {
mm_double4* force = (mm_double4*) cl.getPinnedBuffer();
cl.getForce().download(force);
for (int i = 0; i < numParticles; ++i) {
mm_double4 f = force[i];
forces[order[i]] = Vec3(f.x, f.y, f.z);
}
}
else {
mm_float4* force = (mm_float4*) cl.getPinnedBuffer();
cl.getForce().download(force);
for (int i = 0; i < numParticles; ++i) {
mm_float4 f = force[i];
forces[order[i]] = Vec3(f.x, f.y, f.z);
}
}
}
void OpenCLUpdateStateDataKernel::getEnergyParameterDerivatives(ContextImpl& context, map<string, double>& derivs) {
const vector<string>& paramDerivNames = cl.getEnergyParamDerivNames();
int numDerivs = paramDerivNames.size();
if (numDerivs == 0)
return;
derivs = cl.getEnergyParamDerivWorkspace();
OpenCLArray& derivArray = cl.getEnergyParamDerivBuffer();
if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) {
vector<double> derivBuffers;
derivArray.download(derivBuffers);
for (int i = numDerivs; i < derivArray.getSize(); i += numDerivs)
for (int j = 0; j < numDerivs; j++)
derivBuffers[j] += derivBuffers[i+j];
for (int i = 0; i < numDerivs; i++)
derivs[paramDerivNames[i]] += derivBuffers[i];
}
else {
vector<float> derivBuffers;
derivArray.download(derivBuffers);
for (int i = numDerivs; i < derivArray.getSize(); i += numDerivs)
for (int j = 0; j < numDerivs; j++)
derivBuffers[j] += derivBuffers[i+j];
for (int i = 0; i < numDerivs; i++)
derivs[paramDerivNames[i]] += derivBuffers[i];
}
}
void OpenCLUpdateStateDataKernel::getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const {
cl.getPeriodicBoxVectors(a, b, c);
}
void OpenCLUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) {
vector<OpenCLContext*>& contexts = cl.getPlatformData().contexts;
// If any particles have been wrapped to the first periodic box, we need to unwrap them
// to avoid changing their positions.
vector<Vec3> positions;
for (auto offset : cl.getPosCellOffsets()) {
if (offset.x != 0 || offset.y != 0 || offset.z != 0) {
getPositions(context, positions);
break;
}
}
// Update the vectors.
for (auto ctx : contexts)
ctx->setPeriodicBoxVectors(a, b, c);
if (positions.size() > 0)
setPositions(context, positions);
}
void OpenCLUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream& stream) {
int version = 3;
stream.write((char*) &version, sizeof(int));
int precision = (cl.getUseDoublePrecision() ? 2 : cl.getUseMixedPrecision() ? 1 : 0);
stream.write((char*) &precision, sizeof(int));
double time = cl.getTime();
stream.write((char*) &time, sizeof(double));
long long stepCount = cl.getStepCount();
stream.write((char*) &stepCount, sizeof(long long));
int stepsSinceReorder = cl.getStepsSinceReorder();
stream.write((char*) &stepsSinceReorder, sizeof(int));
char* buffer = (char*) cl.getPinnedBuffer();
cl.getPosq().download(buffer);
stream.write(buffer, cl.getPosq().getSize()*cl.getPosq().getElementSize());
if (cl.getUseMixedPrecision()) {
cl.getPosqCorrection().download(buffer);
stream.write(buffer, cl.getPosqCorrection().getSize()*cl.getPosqCorrection().getElementSize());
}
cl.getVelm().download(buffer);
stream.write(buffer, cl.getVelm().getSize()*cl.getVelm().getElementSize());
stream.write((char*) &cl.getAtomIndex()[0], sizeof(cl_int)*cl.getAtomIndex().size());
stream.write((char*) &cl.getPosCellOffsets()[0], sizeof(mm_int4)*cl.getPosCellOffsets().size());
Vec3 boxVectors[3];
cl.getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
stream.write((char*) boxVectors, 3*sizeof(Vec3));
cl.getIntegrationUtilities().createCheckpoint(stream);
SimTKOpenMMUtilities::createCheckpoint(stream);
}
void OpenCLUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream& stream) {
int version;
stream.read((char*) &version, sizeof(int));
if (version != 3)
throw OpenMMException("Checkpoint was created with a different version of OpenMM");
int precision;
stream.read((char*) &precision, sizeof(int));
int expectedPrecision = (cl.getUseDoublePrecision() ? 2 : cl.getUseMixedPrecision() ? 1 : 0);
if (precision != expectedPrecision)
throw OpenMMException("Checkpoint was created with a different numeric precision");
double time;
stream.read((char*) &time, sizeof(double));
long long stepCount;
stream.read((char*) &stepCount, sizeof(long long));
int stepsSinceReorder;
stream.read((char*) &stepsSinceReorder, sizeof(int));
vector<OpenCLContext*>& contexts = cl.getPlatformData().contexts;
for (auto ctx : contexts) {
ctx->setTime(time);
ctx->setStepCount(stepCount);
ctx->setStepsSinceReorder(stepsSinceReorder);
}
char* buffer = (char*) cl.getPinnedBuffer();
stream.read(buffer, cl.getPosq().getSize()*cl.getPosq().getElementSize());
cl.getPosq().upload(buffer);
if (cl.getUseMixedPrecision()) {
stream.read(buffer, cl.getPosqCorrection().getSize()*cl.getPosqCorrection().getElementSize());
cl.getPosqCorrection().upload(buffer);
}
stream.read(buffer, cl.getVelm().getSize()*cl.getVelm().getElementSize());
cl.getVelm().upload(buffer);
stream.read((char*) &cl.getAtomIndex()[0], sizeof(cl_int)*cl.getAtomIndex().size());
cl.getAtomIndexArray().upload(cl.getAtomIndex());
stream.read((char*) &cl.getPosCellOffsets()[0], sizeof(mm_int4)*cl.getPosCellOffsets().size());
Vec3 boxVectors[3];
stream.read((char*) &boxVectors, 3*sizeof(Vec3));
for (auto ctx : contexts)
ctx->setPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
cl.getIntegrationUtilities().loadCheckpoint(stream);
SimTKOpenMMUtilities::loadCheckpoint(stream);
for (auto listener : cl.getReorderListeners())
listener->execute();
cl.validateAtomOrder();
}
class OpenCLCalcNonbondedForceKernel::ForceInfo : public OpenCLForceInfo { class OpenCLCalcNonbondedForceKernel::ForceInfo : public OpenCLForceInfo {
public: public:
ForceInfo(int requiredBuffers, const NonbondedForce& force) : OpenCLForceInfo(requiredBuffers), force(force) { ForceInfo(int requiredBuffers, const NonbondedForce& force) : OpenCLForceInfo(requiredBuffers), force(force) {
...@@ -936,7 +595,7 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb ...@@ -936,7 +595,7 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb
} }
exclusionAtoms.upload(exclusionAtomsVec); exclusionAtoms.upload(exclusionAtomsVec);
map<string, string> replacements; map<string, string> replacements;
replacements["PARAMS"] = cl.getBondedUtilities().addArgument(exclusionParams.getDeviceBuffer(), "float4"); replacements["PARAMS"] = cl.getBondedUtilities().addArgument(exclusionParams, "float4");
replacements["EWALD_ALPHA"] = cl.doubleToString(alpha); replacements["EWALD_ALPHA"] = cl.doubleToString(alpha);
replacements["TWO_OVER_SQRT_PI"] = cl.doubleToString(2.0/sqrt(M_PI)); replacements["TWO_OVER_SQRT_PI"] = cl.doubleToString(2.0/sqrt(M_PI));
replacements["DO_LJPME"] = doLJPME ? "1" : "0"; replacements["DO_LJPME"] = doLJPME ? "1" : "0";
...@@ -998,7 +657,7 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb ...@@ -998,7 +657,7 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb
baseExceptionParams.upload(baseExceptionParamsVec); baseExceptionParams.upload(baseExceptionParamsVec);
map<string, string> replacements; map<string, string> replacements;
replacements["APPLY_PERIODIC"] = (usePeriodic && force.getExceptionsUsePeriodicBoundaryConditions() ? "1" : "0"); replacements["APPLY_PERIODIC"] = (usePeriodic && force.getExceptionsUsePeriodicBoundaryConditions() ? "1" : "0");
replacements["PARAMS"] = cl.getBondedUtilities().addArgument(exceptionParams.getDeviceBuffer(), "float4"); replacements["PARAMS"] = cl.getBondedUtilities().addArgument(exceptionParams, "float4");
if (force.getIncludeDirectSpace()) if (force.getIncludeDirectSpace())
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(CommonKernelSources::nonbondedExceptions, replacements), force.getForceGroup()); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(CommonKernelSources::nonbondedExceptions, replacements), force.getForceGroup());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment