Unverified Commit d47ea1de authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Unified storage of global parameters (#5002)

* Unified storage of global parameters

* Fixes to CUDA and HIP

* Store global parameters as real instead of float
parent 6c119bc2
...@@ -73,7 +73,7 @@ public: ...@@ -73,7 +73,7 @@ public:
private: private:
class ForceInfo; class ForceInfo;
double cutoff; double cutoff;
bool hasInitializedKernels, needParameterGradient, needEnergyParamDerivs; bool hasInitializedKernels, needGlobalParams, needParameterGradient, needEnergyParamDerivs;
int maxTiles, numComputedValues; int maxTiles, numComputedValues;
ComputeContext& cc; ComputeContext& cc;
ForceInfo* info; ForceInfo* info;
...@@ -83,9 +83,7 @@ private: ...@@ -83,9 +83,7 @@ private:
ComputeParameterSet* energyDerivChain; ComputeParameterSet* energyDerivChain;
std::vector<ComputeParameterSet*> dValuedParam; std::vector<ComputeParameterSet*> dValuedParam;
std::vector<ComputeArray> dValue0dParam; std::vector<ComputeArray> dValue0dParam;
ComputeArray longEnergyDerivs, globals, valueBuffers; ComputeArray longEnergyDerivs, valueBuffers;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, int> tabulatedFunctionUpdateCount; std::map<std::string, int> tabulatedFunctionUpdateCount;
std::vector<bool> pairValueUsesParam, pairEnergyUsesParam, pairEnergyUsesValue; std::vector<bool> pairValueUsesParam, pairEnergyUsesParam, pairEnergyUsesValue;
......
...@@ -73,19 +73,16 @@ public: ...@@ -73,19 +73,16 @@ public:
private: private:
class ForceInfo; class ForceInfo;
int numDonors, numAcceptors; int numDonors, numAcceptors;
bool hasInitializedKernel, useBoundingBoxes; bool hasInitializedKernel, useBoundingBoxes, needGlobalParams;
ComputeContext& cc; ComputeContext& cc;
ForceInfo* info; ForceInfo* info;
ComputeParameterSet* donorParams; ComputeParameterSet* donorParams;
ComputeParameterSet* acceptorParams; ComputeParameterSet* acceptorParams;
ComputeArray globals;
ComputeArray donors; ComputeArray donors;
ComputeArray acceptors; ComputeArray acceptors;
ComputeArray donorExclusions; ComputeArray donorExclusions;
ComputeArray acceptorExclusions; ComputeArray acceptorExclusions;
ComputeArray donorBlockCenter, donorBlockSize, acceptorBlockCenter, acceptorBlockSize; ComputeArray donorBlockCenter, donorBlockSize, acceptorBlockCenter, acceptorBlockSize;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, int> tabulatedFunctionUpdateCount; std::map<std::string, int> tabulatedFunctionUpdateCount;
const System& system; const System& system;
......
...@@ -76,15 +76,13 @@ private: ...@@ -76,15 +76,13 @@ private:
class ForceInfo; class ForceInfo;
ComputeContext& cc; ComputeContext& cc;
ForceInfo* info; ForceInfo* info;
bool hasInitializedKernel; bool hasInitializedKernel, needGlobalParams;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
int maxNeighborPairs, forceWorkgroupSize, findNeighborsWorkgroupSize; int maxNeighborPairs, forceWorkgroupSize, findNeighborsWorkgroupSize;
ComputeParameterSet* params; ComputeParameterSet* params;
ComputeArray globals, particleTypes, orderIndex, particleOrder; ComputeArray particleTypes, orderIndex, particleOrder;
ComputeArray exclusions, exclusionStartIndex, blockCenter, blockBoundingBox; ComputeArray exclusions, exclusionStartIndex, blockCenter, blockBoundingBox;
ComputeArray neighborPairs, numNeighborPairs, neighborStartIndex, numNeighborsForAtom, neighbors; ComputeArray neighborPairs, numNeighborPairs, neighborStartIndex, numNeighborsForAtom, neighbors;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, int> tabulatedFunctionUpdateCount; std::map<std::string, int> tabulatedFunctionUpdateCount;
const System& system; const System& system;
......
...@@ -82,7 +82,7 @@ private: ...@@ -82,7 +82,7 @@ private:
ForceInfo* info; ForceInfo* info;
ComputeParameterSet* params; ComputeParameterSet* params;
ComputeParameterSet* computedValues; ComputeParameterSet* computedValues;
ComputeArray globals, interactionGroupData, filteredGroupData, numGroupTiles; ComputeArray interactionGroupData, filteredGroupData, numGroupTiles;
ComputeKernel interactionGroupKernel, prepareNeighborListKernel, buildNeighborListKernel, computedValuesKernel; ComputeKernel interactionGroupKernel, prepareNeighborListKernel, buildNeighborListKernel, computedValuesKernel;
std::vector<void*> interactionGroupArgs; std::vector<void*> interactionGroupArgs;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
...@@ -93,7 +93,7 @@ private: ...@@ -93,7 +93,7 @@ private:
std::vector<ComputeParameterInfo> paramBuffers, computedValueBuffers; std::vector<ComputeParameterInfo> paramBuffers, computedValueBuffers;
double longRangeCoefficient; double longRangeCoefficient;
std::vector<double> longRangeCoefficientDerivs; std::vector<double> longRangeCoefficientDerivs;
bool hasInitializedLongRangeCorrection, hasInitializedKernel, hasParamDerivs, useNeighborList; bool hasInitializedLongRangeCorrection, hasInitializedKernel, hasParamDerivs, useNeighborList, needGlobalParams;
int numGroupThreadBlocks; int numGroupThreadBlocks;
CustomNonbondedForce* forceCopy; CustomNonbondedForce* forceCopy;
CustomNonbondedForceImpl::LongRangeCorrectionData longRangeCorrectionData; CustomNonbondedForceImpl::LongRangeCorrectionData longRangeCorrectionData;
......
...@@ -132,7 +132,6 @@ private: ...@@ -132,7 +132,6 @@ private:
ComputeArray exceptionParamOffsets; ComputeArray exceptionParamOffsets;
ComputeArray particleOffsetIndices; ComputeArray particleOffsetIndices;
ComputeArray exceptionOffsetIndices; ComputeArray exceptionOffsetIndices;
ComputeArray globalParams;
ComputeArray cosSinSums; ComputeArray cosSinSums;
ComputeArray pmeGrid1; ComputeArray pmeGrid1;
ComputeArray pmeGrid2; ComputeArray pmeGrid2;
...@@ -163,7 +162,8 @@ private: ...@@ -163,7 +162,8 @@ private:
std::map<std::string, std::string> pmeDefines; std::map<std::string, std::string> pmeDefines;
std::vector<std::pair<int, int> > exceptionAtoms; std::vector<std::pair<int, int> > exceptionAtoms;
std::vector<std::string> paramNames; std::vector<std::string> paramNames;
std::vector<double> paramValues; std::map<std::string, int> paramIndices;
std::map<std::string, double> paramValues;
std::map<int, int> exceptionIndex; std::map<int, int> exceptionIndex;
double ewaldSelfEnergy, dispersionCoefficient, alpha, dispersionAlpha, totalCharge; double ewaldSelfEnergy, dispersionCoefficient, alpha, dispersionAlpha, totalCharge;
int gridSizeX, gridSizeY, gridSizeZ; int gridSizeX, gridSizeY, gridSizeZ;
......
...@@ -300,9 +300,6 @@ private: ...@@ -300,9 +300,6 @@ private:
ForceInfo* info; ForceInfo* info;
const System& system; const System& system;
ComputeParameterSet* params; ComputeParameterSet* params;
ComputeArray globals;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
}; };
/** /**
...@@ -390,9 +387,6 @@ private: ...@@ -390,9 +387,6 @@ private:
ForceInfo* info; ForceInfo* info;
const System& system; const System& system;
ComputeParameterSet* params; ComputeParameterSet* params;
ComputeArray globals;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
}; };
/** /**
...@@ -522,9 +516,6 @@ private: ...@@ -522,9 +516,6 @@ private:
ForceInfo* info; ForceInfo* info;
const System& system; const System& system;
ComputeParameterSet* params; ComputeParameterSet* params;
ComputeArray globals;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
}; };
/** /**
...@@ -613,9 +604,6 @@ private: ...@@ -613,9 +604,6 @@ private:
ForceInfo* info; ForceInfo* info;
const System& system; const System& system;
ComputeParameterSet* params; ComputeParameterSet* params;
ComputeArray globals;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
}; };
/** /**
...@@ -657,9 +645,6 @@ private: ...@@ -657,9 +645,6 @@ private:
ComputeContext& cc; ComputeContext& cc;
ForceInfo* info; ForceInfo* info;
ComputeParameterSet* params; ComputeParameterSet* params;
ComputeArray globals;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, int> tabulatedFunctionUpdateCount; std::map<std::string, int> tabulatedFunctionUpdateCount;
const System& system; const System& system;
...@@ -701,14 +686,12 @@ public: ...@@ -701,14 +686,12 @@ public:
private: private:
class ForceInfo; class ForceInfo;
int numGroups, numBonds; int numGroups, numBonds;
bool needEnergyParamDerivs; bool needGlobalParams, needEnergyParamDerivs;
ComputeContext& cc; ComputeContext& cc;
ForceInfo* info; ForceInfo* info;
ComputeParameterSet* params; ComputeParameterSet* params;
ComputeArray globals, groupParticles, groupWeights, groupOffsets; ComputeArray groupParticles, groupWeights, groupOffsets;
ComputeArray groupForces, bondGroups, centerPositions; ComputeArray groupForces, bondGroups, centerPositions;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, int> tabulatedFunctionUpdateCount; std::map<std::string, int> tabulatedFunctionUpdateCount;
std::vector<void*> groupForcesArgs; std::vector<void*> groupForcesArgs;
......
...@@ -144,7 +144,7 @@ public: ...@@ -144,7 +144,7 @@ public:
/** /**
* Get the ContextImpl is ComputeContext is associated with. * Get the ContextImpl is ComputeContext is associated with.
*/ */
virtual ContextImpl& getContextImpl() = 0; virtual ContextImpl* getContextImpl() = 0;
/** /**
* Get a workspace used for accumulating energy when a simulation is parallelized across * Get a workspace used for accumulating energy when a simulation is parallelized across
* multiple devices. * multiple devices.
...@@ -589,6 +589,26 @@ public: ...@@ -589,6 +589,26 @@ public:
* expense of reduced simulation performance. * expense of reduced simulation performance.
*/ */
virtual void flushQueue() = 0; virtual void flushQueue() = 0;
/**
* Register a global parameter whose value should be stored in the array returned by
* getGlobalParamValues(). This may safely be called multiple times with the same
* parameter name, and it returns the same index each time.
*
* @param name the name of the parameter to register
* @return the index of the parameter in the array
*/
int registerGlobalParam(const std::string& name);
/**
* Get the array which contains the values of global parameters.
*/
ArrayInterface& getGlobalParamValues() {
return globalParamValues;
}
/**
* Make sure the values stored in the array returned by getGlobalParamValues() are
* up to date.
*/
void updateGlobalParamValues();
protected: protected:
struct Molecule; struct Molecule;
struct MoleculeGroup; struct MoleculeGroup;
...@@ -604,7 +624,7 @@ protected: ...@@ -604,7 +624,7 @@ protected:
double time; double time;
int numAtoms, paddedNumAtoms, computeForceCount, stepsSinceReorder; int numAtoms, paddedNumAtoms, computeForceCount, stepsSinceReorder;
long long stepCount; long long stepCount;
bool forceNextReorder, atomsWereReordered, forcesValid; bool forceNextReorder, atomsWereReordered, forcesValid, hasInitializedGlobals;
ComputeQueue defaultQueue, currentQueue; ComputeQueue defaultQueue, currentQueue;
std::vector<ComputeForceInfo*> forces; std::vector<ComputeForceInfo*> forces;
std::vector<Molecule> molecules; std::vector<Molecule> molecules;
...@@ -614,6 +634,9 @@ protected: ...@@ -614,6 +634,9 @@ protected:
std::vector<ReorderListener*> reorderListeners; std::vector<ReorderListener*> reorderListeners;
std::vector<ForcePreComputation*> preComputations; std::vector<ForcePreComputation*> preComputations;
std::vector<ForcePostComputation*> postComputations; std::vector<ForcePostComputation*> postComputations;
std::vector<std::string> globalParamNames;
std::vector<double> lastGlobalParamValues;
ComputeArray globalParamValues;
WorkThread* workThread; WorkThread* workThread;
}; };
......
...@@ -91,6 +91,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -91,6 +91,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
throw OpenMMException("CustomGBForce does not support using multiple devices"); throw OpenMMException("CustomGBForce does not support using multiple devices");
NonbondedUtilities& nb = cc.getNonbondedUtilities(); NonbondedUtilities& nb = cc.getNonbondedUtilities();
cutoff = force.getCutoffDistance(); cutoff = force.getCutoffDistance();
needGlobalParams = (force.getNumGlobalParameters() > 0);
bool useExclusionsForValue = false; bool useExclusionsForValue = false;
numComputedValues = force.getNumComputedValues(); numComputedValues = force.getNumComputedValues();
vector<string> computedValueNames(numComputedValues); vector<string> computedValueNames(numComputedValues);
...@@ -119,8 +120,6 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -119,8 +120,6 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
int numParams = force.getNumPerParticleParameters(); int numParams = force.getNumPerParticleParameters();
params = new ComputeParameterSet(cc, force.getNumPerParticleParameters(), paddedNumParticles, "customGBParameters", true); params = new ComputeParameterSet(cc, force.getNumPerParticleParameters(), paddedNumParticles, "customGBParameters", true);
computedValues = new ComputeParameterSet(cc, numComputedValues, paddedNumParticles, "customGBComputedValues", true, cc.getUseDoublePrecision()); computedValues = new ComputeParameterSet(cc, numComputedValues, paddedNumParticles, "customGBComputedValues", true, cc.getUseDoublePrecision());
if (force.getNumGlobalParameters() > 0)
globals.initialize<float>(cc, force.getNumGlobalParameters(), "customGBGlobals");
vector<vector<float> > paramVector(paddedNumParticles, vector<float>(numParams, 0)); vector<vector<float> > paramVector(paddedNumParticles, vector<float>(numParams, 0));
vector<vector<int> > exclusionList(numParticles); vector<vector<int> > exclusionList(numParticles);
for (int i = 0; i < numParticles; i++) { for (int i = 0; i < numParticles; i++) {
...@@ -163,17 +162,6 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -163,17 +162,6 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
tableArgs << "* RESTRICT " << arrayName; tableArgs << "* RESTRICT " << arrayName;
} }
// Record the global parameters.
globalParamNames.resize(force.getNumGlobalParameters());
globalParamValues.resize(force.getNumGlobalParameters());
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
}
if (globals.isInitialized())
globals.upload(globalParamValues);
// Record derivatives of expressions needed for the chain rule terms. // Record derivatives of expressions needed for the chain rule terms.
vector<vector<Lepton::ParsedExpression> > valueGradientExpressions(numComputedValues); vector<vector<Lepton::ParsedExpression> > valueGradientExpressions(numComputedValues);
...@@ -221,7 +209,6 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -221,7 +209,6 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
energyParamDerivExpressions[i].push_back(ex.differentiate(force.getEnergyParameterDerivativeName(j)).optimize()); energyParamDerivExpressions[i].push_back(ex.differentiate(force.getEnergyParameterDerivativeName(j)).optimize());
} }
bool deviceIsCpu = cc.getIsCPU(); bool deviceIsCpu = cc.getIsCPU();
int elementSize = (cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float));
valueBuffers.initialize<long long>(cc, cc.getPaddedNumAtoms(), "customGBValueBuffers"); valueBuffers.initialize<long long>(cc, cc.getPaddedNumAtoms(), "customGBValueBuffers");
longEnergyDerivs.initialize<long long>(cc, numComputedValues*cc.getPaddedNumAtoms(), "customGBLongEnergyDerivatives"); longEnergyDerivs.initialize<long long>(cc, numComputedValues*cc.getPaddedNumAtoms(), "customGBLongEnergyDerivatives");
energyDerivs = new ComputeParameterSet(cc, numComputedValues, cc.getPaddedNumAtoms(), "customGBEnergyDerivatives", true); energyDerivs = new ComputeParameterSet(cc, numComputedValues, cc.getPaddedNumAtoms(), "customGBEnergyDerivatives", true);
...@@ -260,7 +247,8 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -260,7 +247,8 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
} }
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i); const string& name = force.getGlobalParameterName(i);
string value = "globals["+cc.intToString(i)+"]"; int index = cc.registerGlobalParam(name);
string value = "globals["+cc.intToString(index)+"]";
variables.push_back(makeVariable(name, value)); variables.push_back(makeVariable(name, value));
} }
map<string, Lepton::ParsedExpression> n2ValueExpressions; map<string, Lepton::ParsedExpression> n2ValueExpressions;
...@@ -281,7 +269,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -281,7 +269,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
replacements["COMPUTE_VALUE"] = n2ValueStr; replacements["COMPUTE_VALUE"] = n2ValueStr;
stringstream extraArgs, atomParams, loadLocal1, loadLocal2, load1, load2, tempDerivs1, tempDerivs2, storeDeriv1, storeDeriv2; stringstream extraArgs, atomParams, loadLocal1, loadLocal2, load1, load2, tempDerivs1, tempDerivs2, storeDeriv1, storeDeriv2;
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
extraArgs << ", GLOBAL const float* globals"; extraArgs << ", GLOBAL const real* globals";
pairValueUsesParam.resize(params->getParameterInfos().size(), false); pairValueUsesParam.resize(params->getParameterInfos().size(), false);
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) { for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = params->getParameterInfos()[i]; ComputeParameterInfo& buffer = params->getParameterInfos()[i];
...@@ -351,7 +339,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -351,7 +339,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
stringstream reductionSource, extraArgs, deriv0; stringstream reductionSource, extraArgs, deriv0;
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
extraArgs << ", GLOBAL const float* globals"; extraArgs << ", GLOBAL const real* globals";
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) { for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = params->getParameterInfos()[i]; ComputeParameterInfo& buffer = params->getParameterInfos()[i];
string paramName = "params"+cc.intToString(i+1); string paramName = "params"+cc.intToString(i+1);
...@@ -378,8 +366,10 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -378,8 +366,10 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
variables["z"] = "pos.z"; variables["z"] = "pos.z";
for (int i = 0; i < force.getNumPerParticleParameters(); i++) for (int i = 0; i < force.getNumPerParticleParameters(); i++)
variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]"); variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]");
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++) {
variables[force.getGlobalParameterName(i)] = "globals["+cc.intToString(i)+"]"; int index = cc.registerGlobalParam(force.getGlobalParameterName(i));
variables[force.getGlobalParameterName(i)] = "globals["+cc.intToString(index)+"]";
}
for (int i = 1; i < numComputedValues; i++) { for (int i = 1; i < numComputedValues; i++) {
variables[computedValueNames[i-1]] = "local_values"+computedValues->getParameterSuffix(i-1); variables[computedValueNames[i-1]] = "local_values"+computedValues->getParameterSuffix(i-1);
map<string, Lepton::ParsedExpression> valueExpressions; map<string, Lepton::ParsedExpression> valueExpressions;
...@@ -433,8 +423,10 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -433,8 +423,10 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
variables.push_back(makeVariable(computedValueNames[i]+"1", "values"+computedValues->getParameterSuffix(i, "1"))); variables.push_back(makeVariable(computedValueNames[i]+"1", "values"+computedValues->getParameterSuffix(i, "1")));
variables.push_back(makeVariable(computedValueNames[i]+"2", "values"+computedValues->getParameterSuffix(i, "2"))); variables.push_back(makeVariable(computedValueNames[i]+"2", "values"+computedValues->getParameterSuffix(i, "2")));
} }
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++) {
variables.push_back(makeVariable(force.getGlobalParameterName(i), "globals["+cc.intToString(i)+"]")); int index = cc.registerGlobalParam(force.getGlobalParameterName(i));
variables.push_back(makeVariable(force.getGlobalParameterName(i), "globals["+cc.intToString(index)+"]"));
}
stringstream n2EnergySource; stringstream n2EnergySource;
bool anyExclusions = (force.getNumExclusions() > 0); bool anyExclusions = (force.getNumExclusions() > 0);
for (int i = 0; i < force.getNumEnergyTerms(); i++) { for (int i = 0; i < force.getNumEnergyTerms(); i++) {
...@@ -467,7 +459,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -467,7 +459,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
replacements["COMPUTE_INTERACTION"] = n2EnergyStr; replacements["COMPUTE_INTERACTION"] = n2EnergyStr;
stringstream extraArgs, atomParams, loadLocal1, loadLocal2, clearLocal, load1, load2, declare1, recordDeriv, storeDerivs1, storeDerivs2, initParamDerivs, saveParamDerivs; stringstream extraArgs, atomParams, loadLocal1, loadLocal2, clearLocal, load1, load2, declare1, recordDeriv, storeDerivs1, storeDerivs2, initParamDerivs, saveParamDerivs;
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
extraArgs << ", GLOBAL const float* globals"; extraArgs << ", GLOBAL const real* globals";
pairEnergyUsesParam.resize(params->getParameterInfos().size(), false); pairEnergyUsesParam.resize(params->getParameterInfos().size(), false);
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) { for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = params->getParameterInfos()[i]; ComputeParameterInfo& buffer = params->getParameterInfos()[i];
...@@ -555,7 +547,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -555,7 +547,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
stringstream compute, extraArgs, reduce, initParamDerivs, saveParamDerivs; stringstream compute, extraArgs, reduce, initParamDerivs, saveParamDerivs;
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
extraArgs << ", GLOBAL const float* globals"; extraArgs << ", GLOBAL const real* globals";
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) { for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = params->getParameterInfos()[i]; ComputeParameterInfo& buffer = params->getParameterInfos()[i];
string paramName = "params"+cc.intToString(i+1); string paramName = "params"+cc.intToString(i+1);
...@@ -601,8 +593,10 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -601,8 +593,10 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
variables["z"] = "pos.z"; variables["z"] = "pos.z";
for (int i = 0; i < force.getNumPerParticleParameters(); i++) for (int i = 0; i < force.getNumPerParticleParameters(); i++)
variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]"); variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]");
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++) {
variables[force.getGlobalParameterName(i)] = "globals["+cc.intToString(i)+"]"; int index = cc.registerGlobalParam(force.getGlobalParameterName(i));
variables[force.getGlobalParameterName(i)] = "globals["+cc.intToString(index)+"]";
}
for (int i = 0; i < numComputedValues; i++) for (int i = 0; i < numComputedValues; i++)
variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]"); variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]");
map<string, Lepton::ParsedExpression> expressions; map<string, Lepton::ParsedExpression> expressions;
...@@ -671,7 +665,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -671,7 +665,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
stringstream compute, extraArgs, initParamDerivs, saveParamDerivs; stringstream compute, extraArgs, initParamDerivs, saveParamDerivs;
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
extraArgs << ", GLOBAL const float* globals"; extraArgs << ", GLOBAL const real* globals";
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) { for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = params->getParameterInfos()[i]; ComputeParameterInfo& buffer = params->getParameterInfos()[i];
string paramName = "params"+cc.intToString(i+1); string paramName = "params"+cc.intToString(i+1);
...@@ -707,8 +701,10 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -707,8 +701,10 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
variables["z"] = "pos.z"; variables["z"] = "pos.z";
for (int i = 0; i < force.getNumPerParticleParameters(); i++) for (int i = 0; i < force.getNumPerParticleParameters(); i++)
variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]"); variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]");
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++) {
variables[force.getGlobalParameterName(i)] = "globals["+cc.intToString(i)+"]"; int index = cc.registerGlobalParam(force.getGlobalParameterName(i));
variables[force.getGlobalParameterName(i)] = "globals["+cc.intToString(index)+"]";
}
for (int i = 0; i < numComputedValues; i++) for (int i = 0; i < numComputedValues; i++)
variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]"); variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]");
if (needParameterGradient) { if (needParameterGradient) {
...@@ -757,7 +753,8 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -757,7 +753,8 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
vector<pair<ExpressionTreeNode, string> > globalVariables; vector<pair<ExpressionTreeNode, string> > globalVariables;
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i); const string& name = force.getGlobalParameterName(i);
string value = "globals["+cc.intToString(i)+"]"; int index = cc.registerGlobalParam(force.getGlobalParameterName(i));
string value = "globals["+cc.intToString(index)+"]";
globalVariables.push_back(makeVariable(name, prefix+value)); globalVariables.push_back(makeVariable(name, prefix+value));
} }
vector<pair<ExpressionTreeNode, string> > variables = globalVariables; vector<pair<ExpressionTreeNode, string> > variables = globalVariables;
...@@ -818,10 +815,8 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -818,10 +815,8 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
parameters.push_back(ComputeParameterInfo(buffer.getArray(), paramName, buffer.getComponentType(), buffer.getNumComponents())); parameters.push_back(ComputeParameterInfo(buffer.getArray(), paramName, buffer.getComponentType(), buffer.getNumComponents()));
} }
} }
if (globals.isInitialized()) { if (needGlobalParams)
globals.upload(globalParamValues); arguments.push_back(ComputeParameterInfo(cc.getGlobalParamValues(), prefix+"globals", "real", 1));
arguments.push_back(ComputeParameterInfo(globals, prefix+"globals", "float", 1));
}
nb.addInteraction(useCutoff, usePeriodic, force.getNumExclusions() > 0, cutoff, exclusionList, source, force.getForceGroup()); nb.addInteraction(useCutoff, usePeriodic, force.getNumExclusions() > 0, cutoff, exclusionList, source, force.getForceGroup());
for (auto param : parameters) for (auto param : parameters)
nb.addParameter(param); nb.addParameter(param);
...@@ -835,9 +830,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -835,9 +830,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
ContextSelector selector(cc); ContextSelector selector(cc);
bool deviceIsCpu = cc.getIsCPU();
NonbondedUtilities& nb = cc.getNonbondedUtilities(); NonbondedUtilities& nb = cc.getNonbondedUtilities();
int elementSize = (cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float));
if (!hasInitializedKernels) { if (!hasInitializedKernels) {
hasInitializedKernels = true; hasInitializedKernels = true;
...@@ -893,8 +886,8 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -893,8 +886,8 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
} }
else else
pairValueKernel->addArg(numAtomBlocks*(numAtomBlocks+1)/2); pairValueKernel->addArg(numAtomBlocks*(numAtomBlocks+1)/2);
if (globals.isInitialized()) if (needGlobalParams)
pairValueKernel->addArg(globals); pairValueKernel->addArg(cc.getGlobalParamValues());
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) { for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
if (pairValueUsesParam[i]) { if (pairValueUsesParam[i]) {
ComputeParameterInfo& buffer = params->getParameterInfos()[i]; ComputeParameterInfo& buffer = params->getParameterInfos()[i];
...@@ -907,8 +900,8 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -907,8 +900,8 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
pairValueKernel->addArg(function); pairValueKernel->addArg(function);
perParticleValueKernel->addArg(cc.getPosq()); perParticleValueKernel->addArg(cc.getPosq());
perParticleValueKernel->addArg(valueBuffers); perParticleValueKernel->addArg(valueBuffers);
if (globals.isInitialized()) if (needGlobalParams)
perParticleValueKernel->addArg(globals); perParticleValueKernel->addArg(cc.getGlobalParamValues());
for (auto& buffer : params->getParameterInfos()) for (auto& buffer : params->getParameterInfos())
perParticleValueKernel->addArg(buffer.getArray()); perParticleValueKernel->addArg(buffer.getArray());
for (auto& buffer : computedValues->getParameterInfos()) for (auto& buffer : computedValues->getParameterInfos())
...@@ -938,8 +931,8 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -938,8 +931,8 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
} }
else else
pairEnergyKernel->addArg(numAtomBlocks*(numAtomBlocks+1)/2); pairEnergyKernel->addArg(numAtomBlocks*(numAtomBlocks+1)/2);
if (globals.isInitialized()) if (needGlobalParams)
pairEnergyKernel->addArg(globals); pairEnergyKernel->addArg(cc.getGlobalParamValues());
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) { for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
if (pairEnergyUsesParam[i]) { if (pairEnergyUsesParam[i]) {
ComputeParameterInfo& buffer = params->getParameterInfos()[i]; ComputeParameterInfo& buffer = params->getParameterInfos()[i];
...@@ -960,8 +953,8 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -960,8 +953,8 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
perParticleEnergyKernel->addArg(cc.getEnergyBuffer()); perParticleEnergyKernel->addArg(cc.getEnergyBuffer());
perParticleEnergyKernel->addArg(cc.getPosq()); perParticleEnergyKernel->addArg(cc.getPosq());
perParticleEnergyKernel->addArg(cc.getLongForceBuffer()); perParticleEnergyKernel->addArg(cc.getLongForceBuffer());
if (globals.isInitialized()) if (needGlobalParams)
perParticleEnergyKernel->addArg(globals); perParticleEnergyKernel->addArg(cc.getGlobalParamValues());
for (auto& buffer : params->getParameterInfos()) for (auto& buffer : params->getParameterInfos())
perParticleEnergyKernel->addArg(buffer.getArray()); perParticleEnergyKernel->addArg(buffer.getArray());
for (auto& buffer : computedValues->getParameterInfos()) for (auto& buffer : computedValues->getParameterInfos())
...@@ -978,8 +971,8 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -978,8 +971,8 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
if (needParameterGradient || needEnergyParamDerivs) { if (needParameterGradient || needEnergyParamDerivs) {
gradientChainRuleKernel->addArg(cc.getPosq()); gradientChainRuleKernel->addArg(cc.getPosq());
gradientChainRuleKernel->addArg(cc.getLongForceBuffer()); gradientChainRuleKernel->addArg(cc.getLongForceBuffer());
if (globals.isInitialized()) if (needGlobalParams)
gradientChainRuleKernel->addArg(globals); gradientChainRuleKernel->addArg(cc.getGlobalParamValues());
for (auto& buffer : params->getParameterInfos()) for (auto& buffer : params->getParameterInfos())
gradientChainRuleKernel->addArg(buffer.getArray()); gradientChainRuleKernel->addArg(buffer.getArray());
for (auto& buffer : computedValues->getParameterInfos()) for (auto& buffer : computedValues->getParameterInfos())
...@@ -996,17 +989,6 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -996,17 +989,6 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
gradientChainRuleKernel->addArg(function); gradientChainRuleKernel->addArg(function);
} }
} }
if (globals.isInitialized()) {
bool changed = false;
for (int i = 0; i < (int) globalParamNames.size(); i++) {
float value = (float) context.getParameter(globalParamNames[i]);
if (value != globalParamValues[i])
changed = true;
globalParamValues[i] = value;
}
if (changed)
globals.upload(globalParamValues);
}
pairEnergyKernel->setArg(5, (int) includeEnergy); pairEnergyKernel->setArg(5, (int) includeEnergy);
if (nb.getUseCutoff()) { if (nb.getUseCutoff()) {
setPeriodicBoxArgs(cc, pairValueKernel, 6); setPeriodicBoxArgs(cc, pairValueKernel, 6);
......
...@@ -152,8 +152,6 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu ...@@ -152,8 +152,6 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
acceptors.initialize<mm_int4>(cc, numAcceptors, "customHbondAcceptors"); acceptors.initialize<mm_int4>(cc, numAcceptors, "customHbondAcceptors");
donorParams = new ComputeParameterSet(cc, force.getNumPerDonorParameters(), numDonors, "customHbondDonorParameters"); donorParams = new ComputeParameterSet(cc, force.getNumPerDonorParameters(), numDonors, "customHbondDonorParameters");
acceptorParams = new ComputeParameterSet(cc, force.getNumPerAcceptorParameters(), numAcceptors, "customHbondAcceptorParameters"); acceptorParams = new ComputeParameterSet(cc, force.getNumPerAcceptorParameters(), numAcceptors, "customHbondAcceptorParameters");
if (force.getNumGlobalParameters() > 0)
globals.initialize<float>(cc, force.getNumGlobalParameters(), "customHbondGlobals");
vector<vector<float> > donorParamVector(numDonors); vector<vector<float> > donorParamVector(numDonors);
vector<mm_int4> donorVector(numDonors); vector<mm_int4> donorVector(numDonors);
for (int i = 0; i < numDonors; i++) { for (int i = 0; i < numDonors; i++) {
...@@ -254,14 +252,6 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu ...@@ -254,14 +252,6 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
// Record information about parameters. // Record information about parameters.
globalParamNames.resize(force.getNumGlobalParameters());
globalParamValues.resize(force.getNumGlobalParameters());
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
}
if (globals.isInitialized())
globals.upload(globalParamValues);
map<string, string> variables; map<string, string> variables;
for (int i = 0; i < force.getNumPerDonorParameters(); i++) { for (int i = 0; i < force.getNumPerDonorParameters(); i++) {
const string& name = force.getPerDonorParameterName(i); const string& name = force.getPerDonorParameterName(i);
...@@ -273,8 +263,10 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu ...@@ -273,8 +263,10 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
} }
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i); const string& name = force.getGlobalParameterName(i);
variables[name] = "globals["+cc.intToString(i)+"]"; int index = cc.registerGlobalParam(name);
variables[name] = "globals["+cc.intToString(index)+"]";
} }
needGlobalParams = (force.getNumGlobalParameters() > 0);
// Now to generate the kernel. First, it needs to calculate all distances, angles, // Now to generate the kernel. First, it needs to calculate all distances, angles,
// and dihedrals the expression depends on. // and dihedrals the expression depends on.
...@@ -354,7 +346,7 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu ...@@ -354,7 +346,7 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
// Next it needs to load parameters from global memory. // Next it needs to load parameters from global memory.
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
extraArgs << ", GLOBAL const float* RESTRICT globals"; extraArgs << ", GLOBAL const real* RESTRICT globals";
for (int i = 0; i < (int) donorParams->getParameterInfos().size(); i++) { for (int i = 0; i < (int) donorParams->getParameterInfos().size(); i++) {
ComputeParameterInfo& parameter = donorParams->getParameterInfos()[i]; ComputeParameterInfo& parameter = donorParams->getParameterInfos()[i];
extraArgs << ", GLOBAL const "+parameter.getType()+"* RESTRICT donor"+parameter.getName(); extraArgs << ", GLOBAL const "+parameter.getType()+"* RESTRICT donor"+parameter.getName();
...@@ -457,17 +449,6 @@ double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl ...@@ -457,17 +449,6 @@ double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl
if (numDonors == 0 || numAcceptors == 0) if (numDonors == 0 || numAcceptors == 0)
return 0.0; return 0.0;
ContextSelector selector(cc); ContextSelector selector(cc);
if (globals.isInitialized()) {
bool changed = false;
for (int i = 0; i < (int) globalParamNames.size(); i++) {
float value = (float) context.getParameter(globalParamNames[i]);
if (value != globalParamValues[i])
changed = true;
globalParamValues[i] = value;
}
if (changed)
globals.upload(globalParamValues);
}
if (!hasInitializedKernel) { if (!hasInitializedKernel) {
hasInitializedKernel = true; hasInitializedKernel = true;
if (useBoundingBoxes) { if (useBoundingBoxes) {
...@@ -495,8 +476,8 @@ double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl ...@@ -495,8 +476,8 @@ double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl
forceKernel->addArg(acceptorBlockCenter); forceKernel->addArg(acceptorBlockCenter);
forceKernel->addArg(acceptorBlockSize); forceKernel->addArg(acceptorBlockSize);
} }
if (globals.isInitialized()) if (needGlobalParams)
forceKernel->addArg(globals); forceKernel->addArg(cc.getGlobalParamValues());
for (auto& parameter : donorParams->getParameterInfos()) for (auto& parameter : donorParams->getParameterInfos())
forceKernel->addArg(parameter.getArray()); forceKernel->addArg(parameter.getArray());
for (auto& parameter : acceptorParams->getParameterInfos()) for (auto& parameter : acceptorParams->getParameterInfos())
......
...@@ -133,12 +133,6 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c ...@@ -133,12 +133,6 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c
// Record information about parameters. // Record information about parameters.
globalParamNames.resize(force.getNumGlobalParameters());
globalParamValues.resize(force.getNumGlobalParameters());
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
}
vector<pair<ExpressionTreeNode, string> > variables; vector<pair<ExpressionTreeNode, string> > variables;
for (int i = 0; i < particlesPerSet; i++) { for (int i = 0; i < particlesPerSet; i++) {
string index = cc.intToString(i+1); string index = cc.intToString(i+1);
...@@ -153,14 +147,12 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c ...@@ -153,14 +147,12 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c
variables.push_back(makeVariable(name+index, "((real) params"+params->getParameterSuffix(i, index)+")")); variables.push_back(makeVariable(name+index, "((real) params"+params->getParameterSuffix(i, index)+")"));
} }
} }
if (force.getNumGlobalParameters() > 0) { needGlobalParams = (force.getNumGlobalParameters() > 0);
globals.initialize<float>(cc, force.getNumGlobalParameters(), "customManyParticleGlobals"); for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globals.upload(globalParamValues); const string& name = force.getGlobalParameterName(i);
for (int i = 0; i < force.getNumGlobalParameters(); i++) { int index = cc.registerGlobalParam(name);
const string& name = force.getGlobalParameterName(i); string value = "globals["+cc.intToString(index)+"]";
string value = "globals["+cc.intToString(i)+"]"; variables.push_back(makeVariable(name, value));
variables.push_back(makeVariable(name, value));
}
} }
// Build data structures for type filters. // Build data structures for type filters.
...@@ -351,7 +343,7 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c ...@@ -351,7 +343,7 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c
stringstream extraArgs; stringstream extraArgs;
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
extraArgs << ", GLOBAL const float* globals"; extraArgs << ", GLOBAL const real* globals";
for (int i = 0; i < (int) params->getParameterInfos().size(); i++) { for (int i = 0; i < (int) params->getParameterInfos().size(); i++) {
ComputeParameterInfo& parameter = params->getParameterInfos()[i]; ComputeParameterInfo& parameter = params->getParameterInfos()[i];
extraArgs<<", GLOBAL const "<<parameter.getType()<<"* RESTRICT global_params"<<(i+1); extraArgs<<", GLOBAL const "<<parameter.getType()<<"* RESTRICT global_params"<<(i+1);
...@@ -423,8 +415,8 @@ double CommonCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bo ...@@ -423,8 +415,8 @@ double CommonCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bo
forceKernel->addArg(exclusions); forceKernel->addArg(exclusions);
forceKernel->addArg(exclusionStartIndex); forceKernel->addArg(exclusionStartIndex);
} }
if (globals.isInitialized()) if (needGlobalParams)
forceKernel->addArg(globals); forceKernel->addArg(cc.getGlobalParamValues());
for (auto& parameter : params->getParameterInfos()) for (auto& parameter : params->getParameterInfos())
forceKernel->addArg(parameter.getArray()); forceKernel->addArg(parameter.getArray());
for (auto& function : tabulatedFunctionArrays) for (auto& function : tabulatedFunctionArrays)
...@@ -473,17 +465,6 @@ double CommonCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bo ...@@ -473,17 +465,6 @@ double CommonCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bo
copyPairsKernel->addArg(neighborStartIndex); copyPairsKernel->addArg(neighborStartIndex);
} }
} }
if (globals.isInitialized()) {
bool changed = false;
for (int i = 0; i < (int) globalParamNames.size(); i++) {
float value = (float) context.getParameter(globalParamNames[i]);
if (value != globalParamValues[i])
changed = true;
globalParamValues[i] = value;
}
if (changed)
globals.upload(globalParamValues);
}
while (true) { while (true) {
int* numPairs = (int*) cc.getPinnedBuffer(); int* numPairs = (int*) cc.getPinnedBuffer();
if (nonbondedMethod != NoCutoff) { if (nonbondedMethod != NoCutoff) {
......
...@@ -152,8 +152,6 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -152,8 +152,6 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
int paddedNumParticles = cc.getPaddedNumAtoms(); int paddedNumParticles = cc.getPaddedNumAtoms();
int numParams = force.getNumPerParticleParameters(); int numParams = force.getNumPerParticleParameters();
params = new ComputeParameterSet(cc, numParams, paddedNumParticles, "customNonbondedParameters", true); params = new ComputeParameterSet(cc, numParams, paddedNumParticles, "customNonbondedParameters", true);
if (force.getNumGlobalParameters() > 0)
globals.initialize<float>(cc, force.getNumGlobalParameters(), "customNonbondedGlobals");
vector<vector<float> > paramVector(paddedNumParticles, vector<float>(numParams, 0)); vector<vector<float> > paramVector(paddedNumParticles, vector<float>(numParams, 0));
vector<vector<int> > exclusionList(numParticles); vector<vector<int> > exclusionList(numParticles);
for (int i = 0; i < numParticles; i++) { for (int i = 0; i < numParticles; i++) {
...@@ -211,8 +209,6 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -211,8 +209,6 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
globalParamNames[i] = force.getGlobalParameterName(i); globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i); globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
} }
if (globals.isInitialized())
globals.upload(globalParamValues);
bool useCutoff = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff); bool useCutoff = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff);
bool usePeriodic = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff && force.getNonbondedMethod() != CustomNonbondedForce::CutoffNonPeriodic); bool usePeriodic = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff && force.getNonbondedMethod() != CustomNonbondedForce::CutoffNonPeriodic);
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize(); Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
...@@ -256,9 +252,11 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -256,9 +252,11 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
variables.push_back(makeVariable(computedValueNames[i]+"1", prefix+"values"+cc.intToString(i+1)+"1")); variables.push_back(makeVariable(computedValueNames[i]+"1", prefix+"values"+cc.intToString(i+1)+"1"));
variables.push_back(makeVariable(computedValueNames[i]+"2", prefix+"values"+cc.intToString(i+1)+"2")); variables.push_back(makeVariable(computedValueNames[i]+"2", prefix+"values"+cc.intToString(i+1)+"2"));
} }
needGlobalParams = (force.getNumGlobalParameters() > 0);
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i); const string& name = force.getGlobalParameterName(i);
string value = "globals["+cc.intToString(i)+"]"; int index = cc.registerGlobalParam(name);
string value = "globals["+cc.intToString(index)+"]";
variables.push_back(makeVariable(name, prefix+value)); variables.push_back(makeVariable(name, prefix+value));
} }
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) { for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
...@@ -291,10 +289,8 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -291,10 +289,8 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
for (int i = 0; i < computedValueBuffers.size(); i++) for (int i = 0; i < computedValueBuffers.size(); i++)
cc.getNonbondedUtilities().addParameter(ComputeParameterInfo(computedValueBuffers[i].getArray(), prefix+"values"+cc.intToString(i+1), cc.getNonbondedUtilities().addParameter(ComputeParameterInfo(computedValueBuffers[i].getArray(), prefix+"values"+cc.intToString(i+1),
computedValueBuffers[i].getComponentType(), computedValueBuffers[i].getNumComponents())); computedValueBuffers[i].getComponentType(), computedValueBuffers[i].getNumComponents()));
if (globals.isInitialized()) { if (needGlobalParams)
globals.upload(globalParamValues); cc.getNonbondedUtilities().addArgument(ComputeParameterInfo(cc.getGlobalParamValues(), prefix+"globals", "real", 1));
cc.getNonbondedUtilities().addArgument(ComputeParameterInfo(globals, prefix+"globals", "float", 1));
}
} }
if (force.getNumComputedValues() > 0) { if (force.getNumComputedValues() > 0) {
// Create the kernel to calculate computed values. // Create the kernel to calculate computed values.
...@@ -309,7 +305,7 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -309,7 +305,7 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
valuesSource << buffer.getType() << " local_" << valueName << ";\n"; valuesSource << buffer.getType() << " local_" << valueName << ";\n";
} }
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
args << ", GLOBAL const float* globals"; args << ", GLOBAL const real* globals";
for (int i = 0; i < params->getParameterInfos().size(); i++) { for (int i = 0; i < params->getParameterInfos().size(); i++) {
ComputeParameterInfo& buffer = params->getParameterInfos()[i]; ComputeParameterInfo& buffer = params->getParameterInfos()[i];
string paramName = "params"+cc.intToString(i+1); string paramName = "params"+cc.intToString(i+1);
...@@ -318,8 +314,11 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -318,8 +314,11 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
map<string, string> variables; map<string, string> variables;
for (int i = 0; i < force.getNumPerParticleParameters(); i++) for (int i = 0; i < force.getNumPerParticleParameters(); i++)
variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]"); variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]");
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++) {
variables[force.getGlobalParameterName(i)] = "globals["+cc.intToString(i)+"]"; const string& name = force.getGlobalParameterName(i);
int index = cc.registerGlobalParam(name);
variables[name] = "globals["+cc.intToString(index)+"]";
}
for (int i = 0; i < force.getNumComputedValues(); i++) { for (int i = 0; i < force.getNumComputedValues(); i++) {
string name, expression; string name, expression;
force.getComputedValueParameters(i, name, expression); force.getComputedValueParameters(i, name, expression);
...@@ -341,8 +340,8 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -341,8 +340,8 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
computedValuesKernel = program->createKernel("computePerParticleValues"); computedValuesKernel = program->createKernel("computePerParticleValues");
for (auto& value : computedValues->getParameterInfos()) for (auto& value : computedValues->getParameterInfos())
computedValuesKernel->addArg(value.getArray()); computedValuesKernel->addArg(value.getArray());
if (globals.isInitialized()) if (needGlobalParams)
computedValuesKernel->addArg(globals); computedValuesKernel->addArg();
for (auto& parameter : params->getParameterInfos()) for (auto& parameter : params->getParameterInfos())
computedValuesKernel->addArg(parameter.getArray()); computedValuesKernel->addArg(parameter.getArray());
for (auto& function : tabulatedFunctionArrays) for (auto& function : tabulatedFunctionArrays)
...@@ -561,8 +560,8 @@ void CommonCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNon ...@@ -561,8 +560,8 @@ void CommonCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNon
args<<", GLOBAL const "<<computedValueBuffers[i].getType()<<"* RESTRICT global_values"<<(i+1); args<<", GLOBAL const "<<computedValueBuffers[i].getType()<<"* RESTRICT global_values"<<(i+1);
for (int i = 0; i < tabulatedFunctionArrays.size(); i++) for (int i = 0; i < tabulatedFunctionArrays.size(); i++)
args << ", GLOBAL const " << tableTypes[i]<< "* RESTRICT table" << i; args << ", GLOBAL const " << tableTypes[i]<< "* RESTRICT table" << i;
if (globals.isInitialized()) if (needGlobalParams)
args<<", GLOBAL const float* RESTRICT globals"; args<<", GLOBAL const real* RESTRICT globals";
if (hasParamDerivs) if (hasParamDerivs)
args << ", GLOBAL mixed* RESTRICT energyParamDerivs"; args << ", GLOBAL mixed* RESTRICT energyParamDerivs";
replacements["PARAMETER_ARGUMENTS"] = args.str(); replacements["PARAMETER_ARGUMENTS"] = args.str();
...@@ -634,18 +633,12 @@ double CommonCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool ...@@ -634,18 +633,12 @@ double CommonCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool
} }
ContextSelector selector(cc); ContextSelector selector(cc);
bool recomputeLongRangeCorrection = !hasInitializedLongRangeCorrection; bool recomputeLongRangeCorrection = !hasInitializedLongRangeCorrection;
if (globals.isInitialized()) { if (needGlobalParams && forceCopy != NULL) {
bool changed = false;
for (int i = 0; i < (int) globalParamNames.size(); i++) { for (int i = 0; i < (int) globalParamNames.size(); i++) {
float value = (float) context.getParameter(globalParamNames[i]); float value = (float) context.getParameter(globalParamNames[i]);
if (value != globalParamValues[i]) if (value != globalParamValues[i])
changed = true;
globalParamValues[i] = value;
}
if (changed) {
globals.upload(globalParamValues);
if (forceCopy != NULL)
recomputeLongRangeCorrection = true; recomputeLongRangeCorrection = true;
globalParamValues[i] = value;
} }
} }
if (recomputeLongRangeCorrection) { if (recomputeLongRangeCorrection) {
...@@ -656,8 +649,10 @@ double CommonCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool ...@@ -656,8 +649,10 @@ double CommonCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool
else else
hasInitializedLongRangeCorrection = false; hasInitializedLongRangeCorrection = false;
} }
if (computedValues != NULL) if (computedValues != NULL) {
computedValuesKernel->setArg(computedValues->getParameterInfos().size(), cc.getGlobalParamValues());
computedValuesKernel->execute(cc.getNumAtoms()); computedValuesKernel->execute(cc.getNumAtoms());
}
if (interactionGroupData.isInitialized()) { if (interactionGroupData.isInitialized()) {
if (!hasInitializedKernel) { if (!hasInitializedKernel) {
hasInitializedKernel = true; hasInitializedKernel = true;
...@@ -676,8 +671,8 @@ double CommonCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool ...@@ -676,8 +671,8 @@ double CommonCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool
interactionGroupKernel->addArg(buffer.getArray()); interactionGroupKernel->addArg(buffer.getArray());
for (auto& function : tabulatedFunctionArrays) for (auto& function : tabulatedFunctionArrays)
interactionGroupKernel->addArg(function); interactionGroupKernel->addArg(function);
if (globals.isInitialized()) if (needGlobalParams)
interactionGroupKernel->addArg(globals); interactionGroupKernel->addArg(cc.getGlobalParamValues());
if (hasParamDerivs) if (hasParamDerivs)
interactionGroupKernel->addArg(cc.getEnergyParamDerivBuffer()); interactionGroupKernel->addArg(cc.getEnergyParamDerivBuffer());
if (useNeighborList) { if (useNeighborList) {
......
...@@ -420,7 +420,7 @@ void CommonCalcNonbondedForceKernel::commonInitialize(const System& system, cons ...@@ -420,7 +420,7 @@ void CommonCalcNonbondedForceKernel::commonInitialize(const System& system, cons
// Create the CPU PME kernel. // Create the CPU PME kernel.
try { try {
cpuPme = getPlatform().createKernel(CalcPmeReciprocalForceKernel::Name(), cc.getContextImpl()); cpuPme = getPlatform().createKernel(CalcPmeReciprocalForceKernel::Name(), *cc.getContextImpl());
cpuPme.getAs<CalcPmeReciprocalForceKernel>().initialize(gridSizeX, gridSizeY, gridSizeZ, numParticles, alpha, false); cpuPme.getAs<CalcPmeReciprocalForceKernel>().initialize(gridSizeX, gridSizeY, gridSizeZ, numParticles, alpha, false);
ComputeProgram program = cc.compileProgram(CommonKernelSources::pme, pmeDefines); ComputeProgram program = cc.compileProgram(CommonKernelSources::pme, pmeDefines);
ComputeKernel addForcesKernel = program->createKernel("addForces"); ComputeKernel addForcesKernel = program->createKernel("addForces");
...@@ -655,15 +655,9 @@ void CommonCalcNonbondedForceKernel::commonInitialize(const System& system, cons ...@@ -655,15 +655,9 @@ void CommonCalcNonbondedForceKernel::commonInitialize(const System& system, cons
int particle; int particle;
double charge, sigma, epsilon; double charge, sigma, epsilon;
force.getParticleParameterOffset(i, param, particle, charge, sigma, epsilon); force.getParticleParameterOffset(i, param, particle, charge, sigma, epsilon);
auto paramPos = find(paramNames.begin(), paramNames.end(), param); int paramIndex = cc.registerGlobalParam(param);
int paramIndex; paramIndices[param] = paramIndex;
if (paramPos == paramNames.end()) { particleOffsetVec[particle].push_back(mm_float4(charge, sigma, epsilon, (float) paramIndex));
paramIndex = paramNames.size();
paramNames.push_back(param);
}
else
paramIndex = paramPos-paramNames.begin();
particleOffsetVec[particle].push_back(mm_float4(charge, sigma, epsilon, paramIndex));
} }
for (int i = 0; i < force.getNumExceptionParameterOffsets(); i++) { for (int i = 0; i < force.getNumExceptionParameterOffsets(); i++) {
string param; string param;
...@@ -673,17 +667,12 @@ void CommonCalcNonbondedForceKernel::commonInitialize(const System& system, cons ...@@ -673,17 +667,12 @@ void CommonCalcNonbondedForceKernel::commonInitialize(const System& system, cons
int index = exceptionIndex[exception]; int index = exceptionIndex[exception];
if (index < startIndex || index >= endIndex) if (index < startIndex || index >= endIndex)
continue; continue;
auto paramPos = find(paramNames.begin(), paramNames.end(), param); int paramIndex = cc.registerGlobalParam(param);
int paramIndex; paramIndices[param] = paramIndex;
if (paramPos == paramNames.end()) { exceptionOffsetVec[index-startIndex].push_back(mm_float4(charge, sigma, epsilon, (float) paramIndex));
paramIndex = paramNames.size();
paramNames.push_back(param);
}
else
paramIndex = paramPos-paramNames.begin();
exceptionOffsetVec[index-startIndex].push_back(mm_float4(charge, sigma, epsilon, paramIndex));
} }
paramValues.resize(paramNames.size(), 0.0); for (int i = 0; i < force.getNumGlobalParameters(); i++)
paramValues[force.getGlobalParameterName(i)] = force.getGlobalParameterDefaultValue(i);
particleParamOffsets.initialize<mm_float4>(cc, max(force.getNumParticleParameterOffsets(), 1), "particleParamOffsets"); particleParamOffsets.initialize<mm_float4>(cc, max(force.getNumParticleParameterOffsets(), 1), "particleParamOffsets");
particleOffsetIndices.initialize<int>(cc, cc.getPaddedNumAtoms()+1, "particleOffsetIndices"); particleOffsetIndices.initialize<int>(cc, cc.getPaddedNumAtoms()+1, "particleOffsetIndices");
vector<int> particleOffsetIndicesVec, exceptionOffsetIndicesVec; vector<int> particleOffsetIndicesVec, exceptionOffsetIndicesVec;
...@@ -711,9 +700,6 @@ void CommonCalcNonbondedForceKernel::commonInitialize(const System& system, cons ...@@ -711,9 +700,6 @@ void CommonCalcNonbondedForceKernel::commonInitialize(const System& system, cons
exceptionParamOffsets.upload(e); exceptionParamOffsets.upload(e);
exceptionOffsetIndices.upload(exceptionOffsetIndicesVec); exceptionOffsetIndices.upload(exceptionOffsetIndicesVec);
} }
globalParams.initialize(cc, max((int) paramValues.size(), 1), cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float), "globalParams");
if (paramValues.size() > 0)
globalParams.upload(paramValues, true);
chargeBuffer.initialize(cc, cc.getNumThreadBlocks(), cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float), "chargeBuffer"); chargeBuffer.initialize(cc, cc.getNumThreadBlocks(), cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float), "chargeBuffer");
cc.clearBuffer(chargeBuffer); cc.clearBuffer(chargeBuffer);
recomputeParams = true; recomputeParams = true;
...@@ -734,7 +720,7 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ ...@@ -734,7 +720,7 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
hasInitializedKernel = true; hasInitializedKernel = true;
computeParamsKernel->addArg(cc.getEnergyBuffer()); computeParamsKernel->addArg(cc.getEnergyBuffer());
computeParamsKernel->addArg(); computeParamsKernel->addArg();
computeParamsKernel->addArg(globalParams); computeParamsKernel->addArg(cc.getGlobalParamValues());
computeParamsKernel->addArg(cc.getPaddedNumAtoms()); computeParamsKernel->addArg(cc.getPaddedNumAtoms());
computeParamsKernel->addArg(baseParticleParams); computeParamsKernel->addArg(baseParticleParams);
computeParamsKernel->addArg(cc.getPosq()); computeParamsKernel->addArg(cc.getPosq());
...@@ -894,18 +880,13 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ ...@@ -894,18 +880,13 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
// Update particle and exception parameters. // Update particle and exception parameters.
bool paramChanged = false; for (auto param : paramValues) {
for (int i = 0; i < paramNames.size(); i++) { double value = context.getParameter(param.first);
double value = context.getParameter(paramNames[i]); if (value != param.second) {
if (value != paramValues[i]) { paramValues[param.first] = value;
paramValues[i] = value;; recomputeParams = true;
paramChanged = true;
} }
} }
if (paramChanged) {
recomputeParams = true;
globalParams.upload(paramValues, true);
}
double energy = 0.0; double energy = 0.0;
if (includeReciprocal && (pmeGrid1.isInitialized() || cosSinSums.isInitialized())) { if (includeReciprocal && (pmeGrid1.isInitialized() || cosSinSums.isInitialized())) {
Vec3 a, b, c; Vec3 a, b, c;
...@@ -1223,11 +1204,10 @@ void CommonCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& contex ...@@ -1223,11 +1204,10 @@ void CommonCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& contex
int particle; int particle;
double charge, sigma, epsilon; double charge, sigma, epsilon;
force.getParticleParameterOffset(i, param, particle, charge, sigma, epsilon); force.getParticleParameterOffset(i, param, particle, charge, sigma, epsilon);
auto paramPos = find(paramNames.begin(), paramNames.end(), param); auto paramIndex = paramIndices.find(param);
if (paramPos == paramNames.end()) if (paramIndex == paramIndices.end())
throw OpenMMException("updateParametersInContext: The parameter of a particle parameter offset has changed"); throw OpenMMException("updateParametersInContext: The parameter of a particle parameter offset has changed");
int paramIndex = paramPos-paramNames.begin(); particleOffsetVec[particle].push_back(mm_float4(charge, sigma, epsilon, paramIndex->second));
particleOffsetVec[particle].push_back(mm_float4(charge, sigma, epsilon, paramIndex));
} }
for (int i = 0; i < force.getNumExceptionParameterOffsets(); i++) { for (int i = 0; i < force.getNumExceptionParameterOffsets(); i++) {
string param; string param;
...@@ -1237,11 +1217,10 @@ void CommonCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& contex ...@@ -1237,11 +1217,10 @@ void CommonCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& contex
int index = exceptionIndex[exception]; int index = exceptionIndex[exception];
if (index < startIndex || index >= endIndex) if (index < startIndex || index >= endIndex)
continue; continue;
auto paramPos = find(paramNames.begin(), paramNames.end(), param); auto paramIndex = paramIndices.find(param);
if (paramPos == paramNames.end()) if (paramIndex == paramIndices.end())
throw OpenMMException("updateParametersInContext: The parameter of an exception parameter offset has changed"); throw OpenMMException("updateParametersInContext: The parameter of an exception parameter offset has changed");
int paramIndex = paramPos-paramNames.begin(); exceptionOffsetVec[index-startIndex].push_back(mm_float4(charge, sigma, epsilon, paramIndex->second));
exceptionOffsetVec[index-startIndex].push_back(mm_float4(charge, sigma, epsilon, paramIndex));
} }
if (max(force.getNumParticleParameterOffsets(), 1) != particleParamOffsets.getSize()) if (max(force.getNumParticleParameterOffsets(), 1) != particleParamOffsets.getSize())
throw OpenMMException("updateParametersInContext: The number of particle parameter offsets has changed"); throw OpenMMException("updateParametersInContext: The number of particle parameter offsets has changed");
......
...@@ -574,12 +574,6 @@ void CommonCalcCustomBondForceKernel::initialize(const System& system, const Cus ...@@ -574,12 +574,6 @@ void CommonCalcCustomBondForceKernel::initialize(const System& system, const Cus
// Record information for the expressions. // Record information for the expressions.
globalParamNames.resize(force.getNumGlobalParameters());
globalParamValues.resize(force.getNumGlobalParameters());
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
}
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize(); Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize();
Lepton::ParsedExpression forceExpression = energyExpression.differentiate("r").optimize(); Lepton::ParsedExpression forceExpression = energyExpression.differentiate("r").optimize();
map<string, Lepton::ParsedExpression> expressions; map<string, Lepton::ParsedExpression> expressions;
...@@ -595,12 +589,11 @@ void CommonCalcCustomBondForceKernel::initialize(const System& system, const Cus ...@@ -595,12 +589,11 @@ void CommonCalcCustomBondForceKernel::initialize(const System& system, const Cus
variables[name] = "bondParams"+params->getParameterSuffix(i); variables[name] = "bondParams"+params->getParameterSuffix(i);
} }
if (force.getNumGlobalParameters() > 0) { if (force.getNumGlobalParameters() > 0) {
globals.initialize<float>(cc, force.getNumGlobalParameters(), "customBondGlobals"); string argName = cc.getBondedUtilities().addArgument(cc.getGlobalParamValues(), "real");
globals.upload(globalParamValues);
string argName = cc.getBondedUtilities().addArgument(globals, "float");
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i); const string& name = force.getGlobalParameterName(i);
string value = argName+"["+cc.intToString(i)+"]"; int index = cc.registerGlobalParam(name);
string value = argName+"["+cc.intToString(index)+"]";
variables[name] = value; variables[name] = value;
} }
} }
...@@ -626,18 +619,6 @@ void CommonCalcCustomBondForceKernel::initialize(const System& system, const Cus ...@@ -626,18 +619,6 @@ void CommonCalcCustomBondForceKernel::initialize(const System& system, const Cus
} }
double CommonCalcCustomBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double CommonCalcCustomBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
ContextSelector selector(cc);
if (globals.isInitialized()) {
bool changed = false;
for (int i = 0; i < (int) globalParamNames.size(); i++) {
float value = (float) context.getParameter(globalParamNames[i]);
if (value != globalParamValues[i])
changed = true;
globalParamValues[i] = value;
}
if (changed)
globals.upload(globalParamValues);
}
return 0.0; return 0.0;
} }
...@@ -809,12 +790,6 @@ void CommonCalcCustomAngleForceKernel::initialize(const System& system, const Cu ...@@ -809,12 +790,6 @@ void CommonCalcCustomAngleForceKernel::initialize(const System& system, const Cu
// Record information for the expressions. // Record information for the expressions.
globalParamNames.resize(force.getNumGlobalParameters());
globalParamValues.resize(force.getNumGlobalParameters());
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
}
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize(); Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize();
Lepton::ParsedExpression forceExpression = energyExpression.differentiate("theta").optimize(); Lepton::ParsedExpression forceExpression = energyExpression.differentiate("theta").optimize();
map<string, Lepton::ParsedExpression> expressions; map<string, Lepton::ParsedExpression> expressions;
...@@ -830,12 +805,11 @@ void CommonCalcCustomAngleForceKernel::initialize(const System& system, const Cu ...@@ -830,12 +805,11 @@ void CommonCalcCustomAngleForceKernel::initialize(const System& system, const Cu
variables[name] = "angleParams"+params->getParameterSuffix(i); variables[name] = "angleParams"+params->getParameterSuffix(i);
} }
if (force.getNumGlobalParameters() > 0) { if (force.getNumGlobalParameters() > 0) {
globals.initialize<float>(cc, force.getNumGlobalParameters(), "customAngleGlobals"); string argName = cc.getBondedUtilities().addArgument(cc.getGlobalParamValues(), "real");
globals.upload(globalParamValues);
string argName = cc.getBondedUtilities().addArgument(globals, "float");
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i); const string& name = force.getGlobalParameterName(i);
string value = argName+"["+cc.intToString(i)+"]"; int index = cc.registerGlobalParam(name);
string value = argName+"["+cc.intToString(index)+"]";
variables[name] = value; variables[name] = value;
} }
} }
...@@ -861,18 +835,6 @@ void CommonCalcCustomAngleForceKernel::initialize(const System& system, const Cu ...@@ -861,18 +835,6 @@ void CommonCalcCustomAngleForceKernel::initialize(const System& system, const Cu
} }
double CommonCalcCustomAngleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double CommonCalcCustomAngleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
ContextSelector selector(cc);
if (globals.isInitialized()) {
bool changed = false;
for (int i = 0; i < (int) globalParamNames.size(); i++) {
float value = (float) context.getParameter(globalParamNames[i]);
if (value != globalParamValues[i])
changed = true;
globalParamValues[i] = value;
}
if (changed)
globals.upload(globalParamValues);
}
return 0.0; return 0.0;
} }
...@@ -1138,12 +1100,6 @@ void CommonCalcCustomTorsionForceKernel::initialize(const System& system, const ...@@ -1138,12 +1100,6 @@ void CommonCalcCustomTorsionForceKernel::initialize(const System& system, const
// Record information for the expressions. // Record information for the expressions.
globalParamNames.resize(force.getNumGlobalParameters());
globalParamValues.resize(force.getNumGlobalParameters());
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
}
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize(); Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize();
Lepton::ParsedExpression forceExpression = energyExpression.differentiate("theta").optimize(); Lepton::ParsedExpression forceExpression = energyExpression.differentiate("theta").optimize();
map<string, Lepton::ParsedExpression> expressions; map<string, Lepton::ParsedExpression> expressions;
...@@ -1159,12 +1115,11 @@ void CommonCalcCustomTorsionForceKernel::initialize(const System& system, const ...@@ -1159,12 +1115,11 @@ void CommonCalcCustomTorsionForceKernel::initialize(const System& system, const
variables[name] = "torsionParams"+params->getParameterSuffix(i); variables[name] = "torsionParams"+params->getParameterSuffix(i);
} }
if (force.getNumGlobalParameters() > 0) { if (force.getNumGlobalParameters() > 0) {
globals.initialize<float>(cc, force.getNumGlobalParameters(), "customTorsionGlobals"); string argName = cc.getBondedUtilities().addArgument(cc.getGlobalParamValues(), "real");
globals.upload(globalParamValues);
string argName = cc.getBondedUtilities().addArgument(globals, "float");
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i); const string& name = force.getGlobalParameterName(i);
string value = argName+"["+cc.intToString(i)+"]"; int index = cc.registerGlobalParam(name);
string value = argName+"["+cc.intToString(index)+"]";
variables[name] = value; variables[name] = value;
} }
} }
...@@ -1190,18 +1145,6 @@ void CommonCalcCustomTorsionForceKernel::initialize(const System& system, const ...@@ -1190,18 +1145,6 @@ void CommonCalcCustomTorsionForceKernel::initialize(const System& system, const
} }
double CommonCalcCustomTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double CommonCalcCustomTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
ContextSelector selector(cc);
if (globals.isInitialized()) {
bool changed = false;
for (int i = 0; i < (int) globalParamNames.size(); i++) {
float value = (float) context.getParameter(globalParamNames[i]);
if (value != globalParamValues[i])
changed = true;
globalParamValues[i] = value;
}
if (changed)
globals.upload(globalParamValues);
}
return 0.0; return 0.0;
} }
...@@ -1411,12 +1354,6 @@ void CommonCalcCustomExternalForceKernel::initialize(const System& system, const ...@@ -1411,12 +1354,6 @@ void CommonCalcCustomExternalForceKernel::initialize(const System& system, const
// Record information for the expressions. // Record information for the expressions.
globalParamNames.resize(force.getNumGlobalParameters());
globalParamValues.resize(force.getNumGlobalParameters());
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
}
map<string, Lepton::CustomFunction*> customFunctions; map<string, Lepton::CustomFunction*> customFunctions;
customFunctions["periodicdistance"] = cc.getExpressionUtilities().getPeriodicDistancePlaceholder(); customFunctions["periodicdistance"] = cc.getExpressionUtilities().getPeriodicDistancePlaceholder();
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), customFunctions).optimize(); Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), customFunctions).optimize();
...@@ -1440,12 +1377,11 @@ void CommonCalcCustomExternalForceKernel::initialize(const System& system, const ...@@ -1440,12 +1377,11 @@ void CommonCalcCustomExternalForceKernel::initialize(const System& system, const
variables[name] = "particleParams"+params->getParameterSuffix(i); variables[name] = "particleParams"+params->getParameterSuffix(i);
} }
if (force.getNumGlobalParameters() > 0) { if (force.getNumGlobalParameters() > 0) {
globals.initialize<float>(cc, force.getNumGlobalParameters(), "customExternalGlobals"); string argName = cc.getBondedUtilities().addArgument(cc.getGlobalParamValues(), "real");
globals.upload(globalParamValues);
string argName = cc.getBondedUtilities().addArgument(globals, "float");
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i); const string& name = force.getGlobalParameterName(i);
string value = argName+"["+cc.intToString(i)+"]"; int index = cc.registerGlobalParam(name);
string value = argName+"["+cc.intToString(index)+"]";
variables[name] = value; variables[name] = value;
} }
} }
...@@ -1464,18 +1400,6 @@ void CommonCalcCustomExternalForceKernel::initialize(const System& system, const ...@@ -1464,18 +1400,6 @@ void CommonCalcCustomExternalForceKernel::initialize(const System& system, const
} }
double CommonCalcCustomExternalForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double CommonCalcCustomExternalForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
ContextSelector selector(cc);
if (globals.isInitialized()) {
bool changed = false;
for (int i = 0; i < (int) globalParamNames.size(); i++) {
float value = (float) context.getParameter(globalParamNames[i]);
if (value != globalParamValues[i])
changed = true;
globalParamValues[i] = value;
}
if (changed)
globals.upload(globalParamValues);
}
return 0.0; return 0.0;
} }
...@@ -1575,12 +1499,6 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c ...@@ -1575,12 +1499,6 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c
// Record information about parameters. // Record information about parameters.
globalParamNames.resize(force.getNumGlobalParameters());
globalParamValues.resize(force.getNumGlobalParameters());
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
}
map<string, string> variables; map<string, string> variables;
for (int i = 0; i < particlesPerBond; i++) { for (int i = 0; i < particlesPerBond; i++) {
string index = cc.intToString(i+1); string index = cc.intToString(i+1);
...@@ -1593,12 +1511,11 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c ...@@ -1593,12 +1511,11 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c
variables[name] = "bondParams"+params->getParameterSuffix(i); variables[name] = "bondParams"+params->getParameterSuffix(i);
} }
if (force.getNumGlobalParameters() > 0) { if (force.getNumGlobalParameters() > 0) {
globals.initialize<float>(cc, force.getNumGlobalParameters(), "customCompoundBondGlobals"); string argName = cc.getBondedUtilities().addArgument(cc.getGlobalParamValues(), "real");
globals.upload(globalParamValues);
string argName = cc.getBondedUtilities().addArgument(globals, "float");
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i); const string& name = force.getGlobalParameterName(i);
string value = argName+"["+cc.intToString(i)+"]"; int index = cc.registerGlobalParam(name);
string value = argName+"["+cc.intToString(index)+"]";
variables[name] = value; variables[name] = value;
} }
} }
...@@ -1644,18 +1561,6 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c ...@@ -1644,18 +1561,6 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c
} }
double CommonCalcCustomCompoundBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double CommonCalcCustomCompoundBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
ContextSelector selector(cc);
if (globals.isInitialized()) {
bool changed = false;
for (int i = 0; i < (int) globalParamNames.size(); i++) {
float value = (float) context.getParameter(globalParamNames[i]);
if (value != globalParamValues[i])
changed = true;
globalParamValues[i] = value;
}
if (changed)
globals.upload(globalParamValues);
}
return 0.0; return 0.0;
} }
...@@ -1827,12 +1732,6 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c ...@@ -1827,12 +1732,6 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c
// Record information about parameters. // Record information about parameters.
globalParamNames.resize(force.getNumGlobalParameters());
globalParamValues.resize(force.getNumGlobalParameters());
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
}
map<string, string> variables; map<string, string> variables;
for (int i = 0; i < groupsPerBond; i++) { for (int i = 0; i < groupsPerBond; i++) {
string index = cc.intToString(i+1); string index = cc.intToString(i+1);
...@@ -1847,13 +1746,13 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c ...@@ -1847,13 +1746,13 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c
needEnergyParamDerivs = (force.getNumEnergyParameterDerivatives() > 0); needEnergyParamDerivs = (force.getNumEnergyParameterDerivatives() > 0);
if (needEnergyParamDerivs) if (needEnergyParamDerivs)
extraArgs << ", GLOBAL mixed* RESTRICT energyParamDerivs"; extraArgs << ", GLOBAL mixed* RESTRICT energyParamDerivs";
if (force.getNumGlobalParameters() > 0) { needGlobalParams = (force.getNumGlobalParameters() > 0);
globals.initialize<float>(cc, force.getNumGlobalParameters(), "customCentroidBondGlobals"); if (needGlobalParams) {
globals.upload(globalParamValues); extraArgs << ", GLOBAL const real* RESTRICT globals";
extraArgs << ", GLOBAL const float* RESTRICT globals";
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i); const string& name = force.getGlobalParameterName(i);
string value = "globals["+cc.intToString(i)+"]"; int index = cc.registerGlobalParam(name);
string value = "globals["+cc.intToString(index)+"]";
variables[name] = value; variables[name] = value;
} }
} }
...@@ -1942,8 +1841,8 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c ...@@ -1942,8 +1841,8 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c
groupForcesKernel->addArg(); // Deriv buffer hasn't been created yet. groupForcesKernel->addArg(); // Deriv buffer hasn't been created yet.
for (auto& function : tabulatedFunctionArrays) for (auto& function : tabulatedFunctionArrays)
groupForcesKernel->addArg(function); groupForcesKernel->addArg(function);
if (globals.isInitialized()) if (needGlobalParams)
groupForcesKernel->addArg(globals); groupForcesKernel->addArg();
for (auto& parameter : params->getParameterInfos()) for (auto& parameter : params->getParameterInfos())
groupForcesKernel->addArg(parameter.getArray()); groupForcesKernel->addArg(parameter.getArray());
applyForcesKernel = program->createKernel("applyForcesToAtoms"); applyForcesKernel = program->createKernel("applyForcesToAtoms");
...@@ -1959,22 +1858,17 @@ double CommonCalcCustomCentroidBondForceKernel::execute(ContextImpl& context, bo ...@@ -1959,22 +1858,17 @@ double CommonCalcCustomCentroidBondForceKernel::execute(ContextImpl& context, bo
if (numBonds == 0) if (numBonds == 0)
return 0.0; return 0.0;
ContextSelector selector(cc); ContextSelector selector(cc);
if (globals.isInitialized()) {
bool changed = false;
for (int i = 0; i < (int) globalParamNames.size(); i++) {
float value = (float) context.getParameter(globalParamNames[i]);
if (value != globalParamValues[i])
changed = true;
globalParamValues[i] = value;
}
if (changed)
globals.upload(globalParamValues);
}
computeCentersKernel->execute(32*numGroups); computeCentersKernel->execute(32*numGroups);
groupForcesKernel->setArg(2, cc.getEnergyBuffer()); groupForcesKernel->setArg(2, cc.getEnergyBuffer());
setPeriodicBoxArgs(cc, groupForcesKernel, 5); setPeriodicBoxArgs(cc, groupForcesKernel, 5);
if (needEnergyParamDerivs) if (needEnergyParamDerivs)
groupForcesKernel->setArg(10, cc.getEnergyParamDerivBuffer()); groupForcesKernel->setArg(10, cc.getEnergyParamDerivBuffer());
if (needGlobalParams) {
int index = 10+tabulatedFunctionArrays.size();
if (needEnergyParamDerivs)
index += 1;
groupForcesKernel->setArg(index, cc.getGlobalParamValues());
}
groupForcesKernel->execute(numBonds); groupForcesKernel->execute(numBonds);
applyForcesKernel->setArg(5, cc.getLongForceBuffer()); applyForcesKernel->setArg(5, cc.getLongForceBuffer());
applyForcesKernel->execute(32*numGroups); applyForcesKernel->execute(32*numGroups);
......
...@@ -45,7 +45,7 @@ const int ComputeContext::ThreadBlockSize = 64; ...@@ -45,7 +45,7 @@ const int ComputeContext::ThreadBlockSize = 64;
const int ComputeContext::TileSize = 32; const int ComputeContext::TileSize = 32;
ComputeContext::ComputeContext(const System& system) : system(system), time(0.0), stepCount(0), computeForceCount(0), stepsSinceReorder(99999), ComputeContext::ComputeContext(const System& system) : system(system), time(0.0), stepCount(0), computeForceCount(0), stepsSinceReorder(99999),
forceNextReorder(false), atomsWereReordered(false), forcesValid(false) { forceNextReorder(false), atomsWereReordered(false), forcesValid(false), hasInitializedGlobals(false) {
workThread = new WorkThread(); workThread = new WorkThread();
} }
...@@ -708,6 +708,35 @@ int ComputeContext::findLegalFFTDimension(int minimum) { ...@@ -708,6 +708,35 @@ int ComputeContext::findLegalFFTDimension(int minimum) {
} }
} }
int ComputeContext::registerGlobalParam(const string& name) {
for (int i = 0; i < globalParamNames.size(); i++)
if (globalParamNames[i] == name)
return i;
globalParamNames.push_back(name);
return globalParamNames.size()-1;
}
void ComputeContext::updateGlobalParamValues() {
bool changed = false;
if (!hasInitializedGlobals) {
hasInitializedGlobals = true;
int elementSize = (getUseDoublePrecision() ? sizeof(double) : sizeof(float));
globalParamValues.initialize(*this, max(1, (int) globalParamNames.size()), elementSize, "globalParameters");
lastGlobalParamValues.resize(globalParamValues.getSize(), 0.0);
if (globalParamNames.size() > 0)
changed = true;
}
for (int i = 0; i < globalParamNames.size(); i++) {
double value = getContextImpl()->getParameter(globalParamNames[i]);
if (value != lastGlobalParamValues[i]) {
lastGlobalParamValues[i] = value;
changed = true;
}
}
if (changed)
getGlobalParamValues().upload(lastGlobalParamValues, true);
}
struct ComputeContext::WorkThread::ThreadData { struct ComputeContext::WorkThread::ThreadData {
ThreadData(std::queue<ComputeContext::WorkTask*>& tasks, bool& waiting, bool& finished, bool& threwException, OpenMMException& stashedException, ThreadData(std::queue<ComputeContext::WorkTask*>& tasks, bool& waiting, bool& finished, bool& threwException, OpenMMException& stashedException,
mutex& queueLock, condition_variable& waitForTaskCondition, condition_variable& queueEmptyCondition) : mutex& queueLock, condition_variable& waitForTaskCondition, condition_variable& queueEmptyCondition) :
......
...@@ -156,8 +156,8 @@ public: ...@@ -156,8 +156,8 @@ public:
/** /**
* Get the ContextImpl is ComputeContext is associated with. * Get the ContextImpl is ComputeContext is associated with.
*/ */
ContextImpl& getContextImpl() { ContextImpl* getContextImpl() {
return *platformData.context; return platformData.context;
} }
/** /**
* Get a workspace used for accumulating energy when a simulation is parallelized across * Get a workspace used for accumulating energy when a simulation is parallelized across
......
...@@ -68,7 +68,6 @@ class CudaContext; ...@@ -68,7 +68,6 @@ class CudaContext;
class OPENMM_EXPORT_COMMON CudaNonbondedUtilities : public NonbondedUtilities { class OPENMM_EXPORT_COMMON CudaNonbondedUtilities : public NonbondedUtilities {
public: public:
class ParameterInfo;
CudaNonbondedUtilities(CudaContext& context); CudaNonbondedUtilities(CudaContext& context);
~CudaNonbondedUtilities(); ~CudaNonbondedUtilities();
/** /**
...@@ -92,22 +91,10 @@ public: ...@@ -92,22 +91,10 @@ public:
* Add a per-atom parameter that the default interaction kernel may depend on. * Add a per-atom parameter that the default interaction kernel may depend on.
*/ */
void addParameter(ComputeParameterInfo parameter); void addParameter(ComputeParameterInfo parameter);
/**
* Add a per-atom parameter that the default interaction kernel may depend on.
*
* @deprecated Use the version that takes a ComputeParameterInfo instead.
*/
void addParameter(const ParameterInfo& parameter);
/** /**
* Add an array (other than a per-atom parameter) that should be passed as an argument to the default interaction kernel. * Add an array (other than a per-atom parameter) that should be passed as an argument to the default interaction kernel.
*/ */
void addArgument(ComputeParameterInfo parameter); void addArgument(ComputeParameterInfo parameter);
/**
* Add an array (other than a per-atom parameter) that should be passed as an argument to the default interaction kernel.
*
* @deprecated Use the version that takes a ComputeParameterInfo instead.
*/
void addArgument(const ParameterInfo& parameter);
/** /**
* Register that the interaction kernel will be computing the derivative of the potential energy * Register that the interaction kernel will be computing the derivative of the potential energy
* with respect to a parameter. * with respect to a parameter.
...@@ -295,7 +282,7 @@ public: ...@@ -295,7 +282,7 @@ public:
* @param includeForces whether this kernel should compute forces * @param includeForces whether this kernel should compute forces
* @param includeEnergy whether this kernel should compute potential energy * @param includeEnergy whether this kernel should compute potential energy
*/ */
CUfunction createInteractionKernel(const std::string& source, std::vector<ParameterInfo>& params, std::vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups, bool includeForces, bool includeEnergy); CUfunction createInteractionKernel(const std::string& source, std::vector<ComputeParameterInfo>& params, std::vector<ComputeParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups, bool includeForces, bool includeEnergy);
/** /**
* Create the set of kernels that will be needed for a particular combination of force groups. * Create the set of kernels that will be needed for a particular combination of force groups.
* *
...@@ -310,6 +297,7 @@ public: ...@@ -310,6 +297,7 @@ public:
private: private:
class KernelSet; class KernelSet;
class BlockSortTrait; class BlockSortTrait;
void initParamArgs();
CudaContext& context; CudaContext& context;
std::map<int, KernelSet> groupKernels; std::map<int, KernelSet> groupKernels;
CudaArray exclusionTiles; CudaArray exclusionTiles;
...@@ -336,14 +324,14 @@ private: ...@@ -336,14 +324,14 @@ private:
unsigned int* pinnedCountBuffer; unsigned int* pinnedCountBuffer;
std::vector<void*> forceArgs, findBlockBoundsArgs, computeSortKeysArgs, sortBoxDataArgs, findInteractingBlocksArgs; std::vector<void*> forceArgs, findBlockBoundsArgs, computeSortKeysArgs, sortBoxDataArgs, findInteractingBlocksArgs;
std::vector<std::vector<int> > atomExclusions; std::vector<std::vector<int> > atomExclusions;
std::vector<ParameterInfo> parameters; std::vector<ComputeParameterInfo> parameters;
std::vector<ParameterInfo> arguments; std::vector<ComputeParameterInfo> arguments;
std::vector<std::string> energyParameterDerivatives; std::vector<std::string> energyParameterDerivatives;
std::map<int, double> groupCutoff; std::map<int, double> groupCutoff;
std::map<int, std::string> groupKernelSource; std::map<int, std::string> groupKernelSource;
double maxCutoff; double maxCutoff;
bool useCutoff, usePeriodic, anyExclusions, usePadding, useNeighborList, forceRebuildNeighborList, canUsePairList, useLargeBlocks; bool useCutoff, usePeriodic, anyExclusions, usePadding, useNeighborList, forceRebuildNeighborList, canUsePairList, useLargeBlocks, hasInitializedParams;
int startTileIndex, startBlockIndex, numBlocks, maxExclusions, numForceThreadBlocks, forceThreadBlockSize, numAtoms, groupFlags, numBlockSizes; int startTileIndex, startBlockIndex, numBlocks, maxExclusions, numForceThreadBlocks, forceThreadBlockSize, numAtoms, groupFlags, numBlockSizes, paramStartIndex;
unsigned int maxTiles, maxSinglePairs, tilesAfterReorder; unsigned int maxTiles, maxSinglePairs, tilesAfterReorder;
long long numTiles; long long numTiles;
std::string kernelSource; std::string kernelSource;
...@@ -365,62 +353,6 @@ public: ...@@ -365,62 +353,6 @@ public:
CUfunction findInteractionsWithinBlocksKernel; CUfunction findInteractionsWithinBlocksKernel;
}; };
/**
* This class stores information about a per-atom parameter that may be used in a nonbonded kernel.
*/
class CudaNonbondedUtilities::ParameterInfo {
public:
/**
* Create a ParameterInfo object.
*
* @param name the name of the parameter
* @param type the data type of the parameter's components
* @param numComponents the number of components in the parameter
* @param size the size of the parameter in bytes
* @param memory the memory containing the parameter values
* @param constant whether the memory should be marked as constant
*/
ParameterInfo(const std::string& name, const std::string& componentType, int numComponents, int size, CUdeviceptr memory, bool constant=true) :
name(name), componentType(componentType), numComponents(numComponents), size(size), memory(memory), constant(constant) {
if (numComponents == 1)
type = componentType;
else {
std::stringstream s;
s << componentType << numComponents;
type = s.str();
}
}
const std::string& getName() const {
return name;
}
const std::string& getComponentType() const {
return componentType;
}
const std::string& getType() const {
return type;
}
int getNumComponents() const {
return numComponents;
}
int getSize() const {
return size;
}
CUdeviceptr& getMemory() {
return memory;
}
bool isConstant() const {
return constant;
}
private:
std::string name;
std::string componentType;
std::string type;
int size, numComponents;
CUdeviceptr memory;
bool constant;
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_CUDANONBONDEDUTILITIES_H_*/ #endif /*OPENMM_CUDANONBONDEDUTILITIES_H_*/
#ifndef OPENMM_CUDAPARAMETERSET_H_
#define OPENMM_CUDAPARAMETERSET_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "CudaContext.h"
#include "CudaNonbondedUtilities.h"
#include "openmm/common/ComputeParameterSet.h"
namespace OpenMM {
class CudaNonbondedUtilities;
/**
* This class exists for backward compatibility. For most purposes you can use
* ComputeParameterSet directly instead.
*/
class OPENMM_EXPORT_COMMON CudaParameterSet : public ComputeParameterSet {
public:
/**
* Create an CudaParameterSet.
*
* @param context the context for which to create the parameter set
* @param numParameters the number of parameters for each object
* @param numObjects the number of objects to store parameter values for
* @param name the name of the parameter set
* @param bufferPerParameter if true, a separate buffer is created for each parameter. If false,
* multiple parameters may be combined into a single buffer.
* @param useDoublePrecision whether values should be stored as single or double precision
*/
CudaParameterSet(CudaContext& context, int numParameters, int numObjects, const std::string& name, bool bufferPerParameter=false, bool useDoublePrecision=false);
/**
* Get a set of CudaNonbondedUtilities::ParameterInfo objects which describe the Buffers
* containing the data.
*/
std::vector<CudaNonbondedUtilities::ParameterInfo>& getBuffers() {
return buffers;
}
private:
std::vector<CudaNonbondedUtilities::ParameterInfo> buffers;
};
} // namespace OpenMM
#endif /*OPENMM_CUDAPARAMETERSET_H_*/
...@@ -61,6 +61,7 @@ void CudaCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool ...@@ -61,6 +61,7 @@ void CudaCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool
cu.setForcesValid(true); cu.setForcesValid(true);
ContextSelector selector(cu); ContextSelector selector(cu);
cu.clearAutoclearBuffers(); cu.clearAutoclearBuffers();
cu.updateGlobalParamValues();
for (auto computation : cu.getPreComputations()) for (auto computation : cu.getPreComputations())
computation->computeForceAndEnergy(includeForces, includeEnergy, groups); computation->computeForceAndEnergy(includeForces, includeEnergy, groups);
CudaNonbondedUtilities& nb = cu.getNonbondedUtilities(); CudaNonbondedUtilities& nb = cu.getNonbondedUtilities();
......
...@@ -115,20 +115,10 @@ void CudaNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, ...@@ -115,20 +115,10 @@ void CudaNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic,
} }
void CudaNonbondedUtilities::addParameter(ComputeParameterInfo parameter) { void CudaNonbondedUtilities::addParameter(ComputeParameterInfo parameter) {
parameters.push_back(ParameterInfo(parameter.getName(), parameter.getComponentType(), parameter.getNumComponents(),
parameter.getSize(), context.unwrap(parameter.getArray()).getDevicePointer(), parameter.isConstant()));
}
void CudaNonbondedUtilities::addParameter(const ParameterInfo& parameter) {
parameters.push_back(parameter); parameters.push_back(parameter);
} }
void CudaNonbondedUtilities::addArgument(ComputeParameterInfo parameter) { void CudaNonbondedUtilities::addArgument(ComputeParameterInfo parameter) {
arguments.push_back(ParameterInfo(parameter.getName(), parameter.getComponentType(), parameter.getNumComponents(),
parameter.getSize(), context.unwrap(parameter.getArray()).getDevicePointer(), parameter.isConstant()));
}
void CudaNonbondedUtilities::addArgument(const ParameterInfo& parameter) {
arguments.push_back(parameter); arguments.push_back(parameter);
} }
...@@ -313,10 +303,10 @@ void CudaNonbondedUtilities::initialize(const System& system) { ...@@ -313,10 +303,10 @@ void CudaNonbondedUtilities::initialize(const System& system) {
forceArgs.push_back(&maxSinglePairs); forceArgs.push_back(&maxSinglePairs);
forceArgs.push_back(&singlePairs.getDevicePointer()); forceArgs.push_back(&singlePairs.getDevicePointer());
} }
for (int i = 0; i < (int) parameters.size(); i++) hasInitializedParams = false;
forceArgs.push_back(&parameters[i].getMemory()); paramStartIndex = forceArgs.size();
for (ParameterInfo& arg : arguments) for (int i = 0; i < parameters.size()+arguments.size(); i++)
forceArgs.push_back(&arg.getMemory()); forceArgs.push_back(NULL);
if (energyParameterDerivatives.size() > 0) if (energyParameterDerivatives.size() > 0)
forceArgs.push_back(&context.getEnergyParamDerivBuffer().getDevicePointer()); forceArgs.push_back(&context.getEnergyParamDerivBuffer().getDevicePointer());
if (useCutoff) { if (useCutoff) {
...@@ -423,6 +413,15 @@ void CudaNonbondedUtilities::prepareInteractions(int forceGroups) { ...@@ -423,6 +413,15 @@ void CudaNonbondedUtilities::prepareInteractions(int forceGroups) {
cuEventRecord(downloadCountEvent, context.getCurrentStream()); cuEventRecord(downloadCountEvent, context.getCurrentStream());
} }
void CudaNonbondedUtilities::initParamArgs() {
int index = paramStartIndex;
for (ComputeParameterInfo& param : parameters)
forceArgs[index++] = &context.unwrap(param.getArray()).getDevicePointer();
for (ComputeParameterInfo& arg : arguments)
forceArgs[index++] = &context.unwrap(arg.getArray()).getDevicePointer();
hasInitializedParams = true;
}
void CudaNonbondedUtilities::computeInteractions(int forceGroups, bool includeForces, bool includeEnergy) { void CudaNonbondedUtilities::computeInteractions(int forceGroups, bool includeForces, bool includeEnergy) {
if ((forceGroups&groupFlags) == 0) if ((forceGroups&groupFlags) == 0)
return; return;
...@@ -431,6 +430,8 @@ void CudaNonbondedUtilities::computeInteractions(int forceGroups, bool includeFo ...@@ -431,6 +430,8 @@ void CudaNonbondedUtilities::computeInteractions(int forceGroups, bool includeFo
CUfunction& kernel = (includeForces ? (includeEnergy ? kernels.forceEnergyKernel : kernels.forceKernel) : kernels.energyKernel); CUfunction& kernel = (includeForces ? (includeEnergy ? kernels.forceEnergyKernel : kernels.forceKernel) : kernels.energyKernel);
if (kernel == NULL) if (kernel == NULL)
kernel = createInteractionKernel(kernels.source, parameters, arguments, true, true, forceGroups, includeForces, includeEnergy); kernel = createInteractionKernel(kernels.source, parameters, arguments, true, true, forceGroups, includeForces, includeEnergy);
if (!hasInitializedParams)
initParamArgs();
context.executeKernel(kernel, &forceArgs[0], numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize); context.executeKernel(kernel, &forceArgs[0], numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
} }
if (useNeighborList && numTiles > 0) { if (useNeighborList && numTiles > 0) {
...@@ -536,13 +537,13 @@ void CudaNonbondedUtilities::createKernelsForGroups(int groups) { ...@@ -536,13 +537,13 @@ void CudaNonbondedUtilities::createKernelsForGroups(int groups) {
groupKernels[groups] = kernels; groupKernels[groups] = kernels;
} }
CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source, vector<ParameterInfo>& params, vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups, bool includeForces, bool includeEnergy) { CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source, vector<ComputeParameterInfo>& params, vector<ComputeParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups, bool includeForces, bool includeEnergy) {
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_INTERACTION"] = source; replacements["COMPUTE_INTERACTION"] = source;
const string suffixes[] = {"x", "y", "z", "w"}; const string suffixes[] = {"x", "y", "z", "w"};
stringstream localData; stringstream localData;
int localDataSize = 0; int localDataSize = 0;
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
if (param.getNumComponents() == 1) if (param.getNumComponents() == 1)
localData<<param.getType()<<" "<<param.getName()<<";\n"; localData<<param.getType()<<" "<<param.getName()<<";\n";
else { else {
...@@ -553,7 +554,7 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source, ...@@ -553,7 +554,7 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source,
} }
replacements["ATOM_PARAMETER_DATA"] = localData.str(); replacements["ATOM_PARAMETER_DATA"] = localData.str();
stringstream args; stringstream args;
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
args << ", "; args << ", ";
if (param.isConstant()) if (param.isConstant())
args << "const "; args << "const ";
...@@ -561,7 +562,7 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source, ...@@ -561,7 +562,7 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source,
args << "* __restrict__ global_"; args << "* __restrict__ global_";
args << param.getName(); args << param.getName();
} }
for (const ParameterInfo& arg : arguments) { for (const ComputeParameterInfo& arg : arguments) {
args << ", "; args << ", ";
if (arg.isConstant()) if (arg.isConstant())
args << "const "; args << "const ";
...@@ -574,7 +575,7 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source, ...@@ -574,7 +575,7 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source,
replacements["PARAMETER_ARGUMENTS"] = args.str(); replacements["PARAMETER_ARGUMENTS"] = args.str();
stringstream load1; stringstream load1;
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
load1 << param.getType(); load1 << param.getType();
load1 << " "; load1 << " ";
load1 << param.getName(); load1 << param.getName();
...@@ -591,7 +592,7 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source, ...@@ -591,7 +592,7 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source,
broadcastWarpData << "posq2.y = real_shfl(shflPosq.y, j);\n"; broadcastWarpData << "posq2.y = real_shfl(shflPosq.y, j);\n";
broadcastWarpData << "posq2.z = real_shfl(shflPosq.z, j);\n"; broadcastWarpData << "posq2.z = real_shfl(shflPosq.z, j);\n";
broadcastWarpData << "posq2.w = real_shfl(shflPosq.w, j);\n"; broadcastWarpData << "posq2.w = real_shfl(shflPosq.w, j);\n";
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
broadcastWarpData << param.getType() << " shfl" << param.getName() << ";\n"; broadcastWarpData << param.getType() << " shfl" << param.getName() << ";\n";
for (int j = 0; j < param.getNumComponents(); j++) { for (int j = 0; j < param.getNumComponents(); j++) {
if (param.getNumComponents() == 1) if (param.getNumComponents() == 1)
...@@ -604,27 +605,27 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source, ...@@ -604,27 +605,27 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source,
// Part 2. Defines for off-diagonal exclusions, and neighborlist tiles. // Part 2. Defines for off-diagonal exclusions, and neighborlist tiles.
stringstream declareLocal2; stringstream declareLocal2;
for (const ParameterInfo& param : params) for (const ComputeParameterInfo& param : params)
declareLocal2<<param.getType()<<" shfl"<<param.getName()<<";\n"; declareLocal2<<param.getType()<<" shfl"<<param.getName()<<";\n";
replacements["DECLARE_LOCAL_PARAMETERS"] = declareLocal2.str(); replacements["DECLARE_LOCAL_PARAMETERS"] = declareLocal2.str();
stringstream loadLocal2; stringstream loadLocal2;
for (const ParameterInfo& param : params) for (const ComputeParameterInfo& param : params)
loadLocal2<<"shfl"<<param.getName()<<" = global_"<<param.getName()<<"[j];\n"; loadLocal2<<"shfl"<<param.getName()<<" = global_"<<param.getName()<<"[j];\n";
replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str(); replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
stringstream load2j; stringstream load2j;
for (const ParameterInfo& param : params) for (const ComputeParameterInfo& param : params)
load2j<<param.getType()<<" "<<param.getName()<<"2 = shfl"<<param.getName()<<";\n"; load2j<<param.getType()<<" "<<param.getName()<<"2 = shfl"<<param.getName()<<";\n";
replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str(); replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str();
stringstream load2g; stringstream load2g;
for (const ParameterInfo& param : params) for (const ComputeParameterInfo& param : params)
load2g<<param.getType()<<" "<<param.getName()<<"2 = global_"<<param.getName()<<"[atom2];\n"; load2g<<param.getType()<<" "<<param.getName()<<"2 = global_"<<param.getName()<<"[atom2];\n";
replacements["LOAD_ATOM2_PARAMETERS_FROM_GLOBAL"] = load2g.str(); replacements["LOAD_ATOM2_PARAMETERS_FROM_GLOBAL"] = load2g.str();
stringstream clearLocal; stringstream clearLocal;
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
clearLocal<<"shfl"<<param.getName()<<" = "; clearLocal<<"shfl"<<param.getName()<<" = ";
if (param.getNumComponents() == 1) if (param.getNumComponents() == 1)
clearLocal<<"0;\n"; clearLocal<<"0;\n";
...@@ -654,7 +655,7 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source, ...@@ -654,7 +655,7 @@ CUfunction CudaNonbondedUtilities::createInteractionKernel(const string& source,
shuffleWarpData << "shflForce.x = real_shfl(shflForce.x, tgx+1);\n"; shuffleWarpData << "shflForce.x = real_shfl(shflForce.x, tgx+1);\n";
shuffleWarpData << "shflForce.y = real_shfl(shflForce.y, tgx+1);\n"; shuffleWarpData << "shflForce.y = real_shfl(shflForce.y, tgx+1);\n";
shuffleWarpData << "shflForce.z = real_shfl(shflForce.z, tgx+1);\n"; shuffleWarpData << "shflForce.z = real_shfl(shflForce.z, tgx+1);\n";
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
if (param.getNumComponents() == 1) if (param.getNumComponents() == 1)
shuffleWarpData<<"shfl"<<param.getName()<<"=real_shfl(shfl"<<param.getName()<<", tgx+1);\n"; shuffleWarpData<<"shfl"<<param.getName()<<"=real_shfl(shfl"<<param.getName()<<", tgx+1);\n";
else { else {
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "CudaParameterSet.h"
using namespace OpenMM;
using namespace std;
CudaParameterSet::CudaParameterSet(CudaContext& context, int numParameters, int numObjects, const string& name, bool bufferPerParameter, bool useDoublePrecision) :
ComputeParameterSet(context, numParameters, numObjects, name, bufferPerParameter, useDoublePrecision) {
for (auto& info : getParameterInfos())
buffers.push_back(CudaNonbondedUtilities::ParameterInfo(info.getName(), info.getComponentType(), info.getNumComponents(), info.getSize(), context.unwrap(info.getArray()).getDevicePointer()));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment