Commit f065df01 authored by Peter Eastman's avatar Peter Eastman
Browse files

OpenCL allows different force groups to have different cutoffs

parent ab8d97b3
...@@ -732,7 +732,7 @@ public: ...@@ -732,7 +732,7 @@ public:
*/ */
void copyParametersToContext(ContextImpl& context, const GBSAOBCForce& force); void copyParametersToContext(ContextImpl& context, const GBSAOBCForce& force);
private: private:
double prefactor, surfaceAreaFactor; double prefactor, surfaceAreaFactor, cutoff;
bool hasCreatedKernels; bool hasCreatedKernels;
int maxTiles; int maxTiles;
OpenCLContext& cl; OpenCLContext& cl;
...@@ -783,6 +783,7 @@ public: ...@@ -783,6 +783,7 @@ public:
*/ */
void copyParametersToContext(ContextImpl& context, const CustomGBForce& force); void copyParametersToContext(ContextImpl& context, const CustomGBForce& force);
private: private:
double cutoff;
bool hasInitializedKernels, needParameterGradient; bool hasInitializedKernels, needParameterGradient;
int maxTiles, numComputedValues; int maxTiles, numComputedValues;
OpenCLContext& cl; OpenCLContext& cl;
......
...@@ -135,31 +135,23 @@ public: ...@@ -135,31 +135,23 @@ public:
return forceThreadBlockSize; return forceThreadBlockSize;
} }
/** /**
* Get the cutoff distance. * Get the maximum cutoff distance used by any force group.
*/ */
double getCutoffDistance() { double getMaxCutoffDistance();
return cutoff;
}
/** /**
* Get whether any interactions have been added. * Get whether any interactions have been added.
*/ */
bool getHasInteractions() { bool getHasInteractions() {
return cutoff != -1.0; return (groupCutoff.size() > 0);
}
/**
* Get the force group in which nonbonded interactions should be computed.
*/
int getForceGroup() {
return nonbondedForceGroup;
} }
/** /**
* Prepare to compute interactions. This updates the neighbor list. * Prepare to compute interactions. This updates the neighbor list.
*/ */
void prepareInteractions(); void prepareInteractions(int forceGroups);
/** /**
* Compute the nonbonded interactions. * Compute the nonbonded interactions.
*/ */
void computeInteractions(); void computeInteractions(int forceGroups);
/** /**
* Check to see if the neighbor list arrays are large enough, and make them bigger if necessary. * Check to see if the neighbor list arrays are large enough, and make them bigger if necessary.
*/ */
...@@ -252,16 +244,20 @@ public: ...@@ -252,16 +244,20 @@ public:
* @param arguments arrays (other than per-atom parameters) that should be passed as arguments to the kernel * @param arguments arrays (other than per-atom parameters) that should be passed as arguments to the kernel
* @param useExclusions specifies whether exclusions are applied to this interaction * @param useExclusions specifies whether exclusions are applied to this interaction
* @param isSymmetric specifies whether the interaction is symmetric * @param isSymmetric specifies whether the interaction is symmetric
* @param groups the set of force groups this kernel is for
*/
cl::Kernel createInteractionKernel(const std::string& source, const std::vector<ParameterInfo>& params, const std::vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups);
/**
* Create the set of kernels that will be needed for a particular combination of force groups.
*
* @param groups the set of force groups
*/ */
cl::Kernel createInteractionKernel(const std::string& source, const std::vector<ParameterInfo>& params, const std::vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric) const; void createKernelsForGroups(int groups);
private: private:
class KernelSet;
class BlockSortTrait; class BlockSortTrait;
OpenCLContext& context; OpenCLContext& context;
cl::Kernel forceKernel; std::map<int, KernelSet> groupKernels;
cl::Kernel findBlockBoundsKernel;
cl::Kernel sortBoxDataKernel;
cl::Kernel findInteractingBlocksKernel;
cl::Kernel findInteractionsWithinBlocksKernel;
OpenCLArray* exclusionTiles; OpenCLArray* exclusionTiles;
OpenCLArray* exclusions; OpenCLArray* exclusions;
OpenCLArray* exclusionIndices; OpenCLArray* exclusionIndices;
...@@ -280,12 +276,27 @@ private: ...@@ -280,12 +276,27 @@ private:
std::vector<std::vector<int> > atomExclusions; std::vector<std::vector<int> > atomExclusions;
std::vector<ParameterInfo> parameters; std::vector<ParameterInfo> parameters;
std::vector<ParameterInfo> arguments; std::vector<ParameterInfo> arguments;
std::string kernelSource; std::map<int, double> groupCutoff;
std::map<std::string, std::string> kernelDefines; std::map<int, std::string> groupKernelSource;
double cutoff; double lastCutoff;
bool useCutoff, usePeriodic, deviceIsCpu, anyExclusions, usePadding; bool useCutoff, usePeriodic, deviceIsCpu, anyExclusions, usePadding, forceRebuildNeighborList;
int numForceBuffers, startTileIndex, numTiles, startBlockIndex, numBlocks, numForceThreadBlocks; int numForceBuffers, startTileIndex, numTiles, startBlockIndex, numBlocks, maxExclusions, numForceThreadBlocks;
int forceThreadBlockSize, interactingBlocksThreadBlockSize, nonbondedForceGroup; int forceThreadBlockSize, interactingBlocksThreadBlockSize, groupFlags;
};
/**
* This class stores the kernels to execute for a set of force groups.
*/
class OpenCLNonbondedUtilities::KernelSet {
public:
bool hasForces;
double cutoffDistance;
cl::Kernel forceKernel;
cl::Kernel findBlockBoundsKernel;
cl::Kernel sortBoxDataKernel;
cl::Kernel findInteractingBlocksKernel;
cl::Kernel findInteractionsWithinBlocksKernel;
}; };
/** /**
......
...@@ -43,6 +43,7 @@ ...@@ -43,6 +43,7 @@
#include <algorithm> #include <algorithm>
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <set>
#include <sstream> #include <sstream>
#include <typeinfo> #include <typeinfo>
...@@ -492,13 +493,32 @@ void OpenCLContext::addForce(OpenCLForceInfo* force) { ...@@ -492,13 +493,32 @@ void OpenCLContext::addForce(OpenCLForceInfo* force) {
} }
string OpenCLContext::replaceStrings(const string& input, const std::map<std::string, std::string>& replacements) const { string OpenCLContext::replaceStrings(const string& input, const std::map<std::string, std::string>& replacements) const {
static set<char> symbolChars;
if (symbolChars.size() == 0) {
symbolChars.insert('_');
for (char c = 'a'; c <= 'z'; c++)
symbolChars.insert(c);
for (char c = 'A'; c <= 'Z'; c++)
symbolChars.insert(c);
for (char c = '0'; c <= '9'; c++)
symbolChars.insert(c);
}
string result = input; string result = input;
for (map<string, string>::const_iterator iter = replacements.begin(); iter != replacements.end(); iter++) { for (map<string, string>::const_iterator iter = replacements.begin(); iter != replacements.end(); iter++) {
int index = -1; int index = 0;
int size = iter->first.size();
do { do {
index = result.find(iter->first); index = result.find(iter->first, index);
if (index != result.npos) if (index != result.npos) {
result.replace(index, iter->first.size(), iter->second); if ((index == 0 || symbolChars.find(result[index-1]) == symbolChars.end()) && (index == result.size()-size || symbolChars.find(result[index+size]) == symbolChars.end())) {
// We have found a complete symbol, not part of a longer symbol.
result.replace(index, size, iter->second);
index += iter->second.size();
}
else
index++;
}
} while (index != result.npos); } while (index != result.npos);
} }
return result; return result;
...@@ -1130,7 +1150,7 @@ void OpenCLContext::reorderAtomsImpl() { ...@@ -1130,7 +1150,7 @@ void OpenCLContext::reorderAtomsImpl() {
if (useHilbert) if (useHilbert)
binWidth = (Real) (max(max(maxx-minx, maxy-miny), maxz-minz)/255.0); binWidth = (Real) (max(max(maxx-minx, maxy-miny), maxz-minz)/255.0);
else else
binWidth = (Real) (0.2*nonbonded->getCutoffDistance()); binWidth = (Real) (0.2*nonbonded->getMaxCutoffDistance());
Real invBinWidth = (Real) (1.0/binWidth); Real invBinWidth = (Real) (1.0/binWidth);
int xbins = 1 + (int) ((maxx-minx)*invBinWidth); int xbins = 1 + (int) ((maxx-minx)*invBinWidth);
int ybins = 1 + (int) ((maxy-miny)*invBinWidth); int ybins = 1 + (int) ((maxy-miny)*invBinWidth);
......
...@@ -121,16 +121,13 @@ void OpenCLCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, boo ...@@ -121,16 +121,13 @@ void OpenCLCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, boo
for (vector<OpenCLContext::ForcePreComputation*>::iterator iter = cl.getPreComputations().begin(); iter != cl.getPreComputations().end(); ++iter) for (vector<OpenCLContext::ForcePreComputation*>::iterator iter = cl.getPreComputations().begin(); iter != cl.getPreComputations().end(); ++iter)
(*iter)->computeForceAndEnergy(includeForces, includeEnergy, groups); (*iter)->computeForceAndEnergy(includeForces, includeEnergy, groups);
OpenCLNonbondedUtilities& nb = cl.getNonbondedUtilities(); OpenCLNonbondedUtilities& nb = cl.getNonbondedUtilities();
bool includeNonbonded = ((groups&(1<<nb.getForceGroup())) != 0);
cl.setComputeForceCount(cl.getComputeForceCount()+1); cl.setComputeForceCount(cl.getComputeForceCount()+1);
if (includeNonbonded) nb.prepareInteractions(groups);
nb.prepareInteractions();
} }
double OpenCLCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, bool includeForces, bool includeEnergy, int groups, bool& valid) { double OpenCLCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, bool includeForces, bool includeEnergy, int groups, bool& valid) {
cl.getBondedUtilities().computeInteractions(groups); cl.getBondedUtilities().computeInteractions(groups);
if ((groups&(1<<cl.getNonbondedUtilities().getForceGroup())) != 0) cl.getNonbondedUtilities().computeInteractions(groups);
cl.getNonbondedUtilities().computeInteractions();
double sum = 0.0; double sum = 0.0;
for (vector<OpenCLContext::ForcePostComputation*>::iterator iter = cl.getPostComputations().begin(); iter != cl.getPostComputations().end(); ++iter) for (vector<OpenCLContext::ForcePostComputation*>::iterator iter = cl.getPostComputations().begin(); iter != cl.getPostComputations().end(); ++iter)
sum += (*iter)->computeForceAndEnergy(includeForces, includeEnergy, groups); sum += (*iter)->computeForceAndEnergy(includeForces, includeEnergy, groups);
...@@ -2643,8 +2640,9 @@ void OpenCLCalcGBSAOBCForceKernel::initialize(const System& system, const GBSAOB ...@@ -2643,8 +2640,9 @@ void OpenCLCalcGBSAOBCForceKernel::initialize(const System& system, const GBSAOB
surfaceAreaFactor = -6.0*4*M_PI*force.getSurfaceAreaEnergy(); surfaceAreaFactor = -6.0*4*M_PI*force.getSurfaceAreaEnergy();
bool useCutoff = (force.getNonbondedMethod() != GBSAOBCForce::NoCutoff); bool useCutoff = (force.getNonbondedMethod() != GBSAOBCForce::NoCutoff);
bool usePeriodic = (force.getNonbondedMethod() != GBSAOBCForce::NoCutoff && force.getNonbondedMethod() != GBSAOBCForce::CutoffNonPeriodic); bool usePeriodic = (force.getNonbondedMethod() != GBSAOBCForce::NoCutoff && force.getNonbondedMethod() != GBSAOBCForce::CutoffNonPeriodic);
cutoff = force.getCutoffDistance();
string source = OpenCLKernelSources::gbsaObc2; string source = OpenCLKernelSources::gbsaObc2;
nb.addInteraction(useCutoff, usePeriodic, false, force.getCutoffDistance(), vector<vector<int> >(), source, force.getForceGroup()); nb.addInteraction(useCutoff, usePeriodic, false, cutoff, vector<vector<int> >(), source, force.getForceGroup());
nb.addParameter(OpenCLNonbondedUtilities::ParameterInfo("obcParams", "float", 2, sizeof(cl_float2), params->getDeviceBuffer()));; nb.addParameter(OpenCLNonbondedUtilities::ParameterInfo("obcParams", "float", 2, sizeof(cl_float2), params->getDeviceBuffer()));;
nb.addParameter(OpenCLNonbondedUtilities::ParameterInfo("bornForce", "real", 1, elementSize, bornForce->getDeviceBuffer()));; nb.addParameter(OpenCLNonbondedUtilities::ParameterInfo("bornForce", "real", 1, elementSize, bornForce->getDeviceBuffer()));;
cl.addForce(new OpenCLGBSAOBCForceInfo(nb.getNumForceBuffers(), force)); cl.addForce(new OpenCLGBSAOBCForceInfo(nb.getNumForceBuffers(), force));
...@@ -2663,8 +2661,8 @@ double OpenCLCalcGBSAOBCForceKernel::execute(ContextImpl& context, bool includeF ...@@ -2663,8 +2661,8 @@ double OpenCLCalcGBSAOBCForceKernel::execute(ContextImpl& context, bool includeF
defines["USE_CUTOFF"] = "1"; defines["USE_CUTOFF"] = "1";
if (nb.getUsePeriodic()) if (nb.getUsePeriodic())
defines["USE_PERIODIC"] = "1"; defines["USE_PERIODIC"] = "1";
defines["CUTOFF_SQUARED"] = cl.doubleToString(nb.getCutoffDistance()*nb.getCutoffDistance()); defines["CUTOFF_SQUARED"] = cl.doubleToString(cutoff*cutoff);
defines["CUTOFF"] = cl.doubleToString(nb.getCutoffDistance()); defines["CUTOFF"] = cl.doubleToString(cutoff);
defines["PREFACTOR"] = cl.doubleToString(prefactor); defines["PREFACTOR"] = cl.doubleToString(prefactor);
defines["SURFACE_AREA_FACTOR"] = cl.doubleToString(surfaceAreaFactor); defines["SURFACE_AREA_FACTOR"] = cl.doubleToString(surfaceAreaFactor);
defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms()); defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
...@@ -2856,6 +2854,7 @@ OpenCLCalcCustomGBForceKernel::~OpenCLCalcCustomGBForceKernel() { ...@@ -2856,6 +2854,7 @@ OpenCLCalcCustomGBForceKernel::~OpenCLCalcCustomGBForceKernel() {
void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const CustomGBForce& force) { void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const CustomGBForce& force) {
if (cl.getPlatformData().contexts.size() > 1) if (cl.getPlatformData().contexts.size() > 1)
throw OpenMMException("CustomGBForce does not support using multiple OpenCL devices"); throw OpenMMException("CustomGBForce does not support using multiple OpenCL devices");
cutoff = force.getCutoffDistance();
bool useExclusionsForValue = false; bool useExclusionsForValue = false;
numComputedValues = force.getNumComputedValues(); numComputedValues = force.getNumComputedValues();
vector<string> computedValueNames(force.getNumComputedValues()); vector<string> computedValueNames(force.getNumComputedValues());
...@@ -3047,7 +3046,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3047,7 +3046,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
if (useExclusionsForValue) if (useExclusionsForValue)
pairValueDefines["USE_EXCLUSIONS"] = "1"; pairValueDefines["USE_EXCLUSIONS"] = "1";
pairValueDefines["FORCE_WORK_GROUP_SIZE"] = cl.intToString(cl.getNonbondedUtilities().getForceThreadBlockSize()); pairValueDefines["FORCE_WORK_GROUP_SIZE"] = cl.intToString(cl.getNonbondedUtilities().getForceThreadBlockSize());
pairValueDefines["CUTOFF_SQUARED"] = cl.doubleToString(force.getCutoffDistance()*force.getCutoffDistance()); pairValueDefines["CUTOFF_SQUARED"] = cl.doubleToString(cutoff*cutoff);
pairValueDefines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms()); pairValueDefines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
pairValueDefines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms()); pairValueDefines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
pairValueDefines["NUM_BLOCKS"] = cl.intToString(cl.getNumAtomBlocks()); pairValueDefines["NUM_BLOCKS"] = cl.intToString(cl.getNumAtomBlocks());
...@@ -3240,7 +3239,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3240,7 +3239,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
if (anyExclusions) if (anyExclusions)
pairEnergyDefines["USE_EXCLUSIONS"] = "1"; pairEnergyDefines["USE_EXCLUSIONS"] = "1";
pairEnergyDefines["FORCE_WORK_GROUP_SIZE"] = cl.intToString(cl.getNonbondedUtilities().getForceThreadBlockSize()); pairEnergyDefines["FORCE_WORK_GROUP_SIZE"] = cl.intToString(cl.getNonbondedUtilities().getForceThreadBlockSize());
pairEnergyDefines["CUTOFF_SQUARED"] = cl.doubleToString(force.getCutoffDistance()*force.getCutoffDistance()); pairEnergyDefines["CUTOFF_SQUARED"] = cl.doubleToString(cutoff*cutoff);
pairEnergyDefines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms()); pairEnergyDefines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
pairEnergyDefines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms()); pairEnergyDefines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
pairEnergyDefines["NUM_BLOCKS"] = cl.intToString(cl.getNumAtomBlocks()); pairEnergyDefines["NUM_BLOCKS"] = cl.intToString(cl.getNumAtomBlocks());
...@@ -3492,7 +3491,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3492,7 +3491,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
globals->upload(globalParamValues); globals->upload(globalParamValues);
arguments.push_back(OpenCLNonbondedUtilities::ParameterInfo(prefix+"globals", "float", 1, sizeof(cl_float), globals->getDeviceBuffer())); arguments.push_back(OpenCLNonbondedUtilities::ParameterInfo(prefix+"globals", "float", 1, sizeof(cl_float), globals->getDeviceBuffer()));
} }
cl.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, force.getNumExclusions() > 0, force.getCutoffDistance(), exclusionList, source, force.getForceGroup()); cl.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, force.getNumExclusions() > 0, cutoff, exclusionList, source, force.getForceGroup());
for (int i = 0; i < (int) parameters.size(); i++) for (int i = 0; i < (int) parameters.size(); i++)
cl.getNonbondedUtilities().addParameter(parameters[i]); cl.getNonbondedUtilities().addParameter(parameters[i]);
for (int i = 0; i < (int) arguments.size(); i++) for (int i = 0; i < (int) arguments.size(); i++)
...@@ -3527,7 +3526,7 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3527,7 +3526,7 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
int endExclusionIndex = (cl.getContextIndex()+1)*numExclusionTiles/numContexts; int endExclusionIndex = (cl.getContextIndex()+1)*numExclusionTiles/numContexts;
pairValueDefines["FIRST_EXCLUSION_TILE"] = cl.intToString(startExclusionIndex); pairValueDefines["FIRST_EXCLUSION_TILE"] = cl.intToString(startExclusionIndex);
pairValueDefines["LAST_EXCLUSION_TILE"] = cl.intToString(endExclusionIndex); pairValueDefines["LAST_EXCLUSION_TILE"] = cl.intToString(endExclusionIndex);
pairValueDefines["CUTOFF"] = cl.doubleToString(nb.getCutoffDistance()); pairValueDefines["CUTOFF"] = cl.doubleToString(cutoff);
cl::Program program = cl.createProgram(pairValueSrc, pairValueDefines); cl::Program program = cl.createProgram(pairValueSrc, pairValueDefines);
pairValueKernel = cl::Kernel(program, "computeN2Value"); pairValueKernel = cl::Kernel(program, "computeN2Value");
pairValueSrc = ""; pairValueSrc = "";
...@@ -3541,7 +3540,7 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3541,7 +3540,7 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
int endExclusionIndex = (cl.getContextIndex()+1)*numExclusionTiles/numContexts; int endExclusionIndex = (cl.getContextIndex()+1)*numExclusionTiles/numContexts;
pairEnergyDefines["FIRST_EXCLUSION_TILE"] = cl.intToString(startExclusionIndex); pairEnergyDefines["FIRST_EXCLUSION_TILE"] = cl.intToString(startExclusionIndex);
pairEnergyDefines["LAST_EXCLUSION_TILE"] = cl.intToString(endExclusionIndex); pairEnergyDefines["LAST_EXCLUSION_TILE"] = cl.intToString(endExclusionIndex);
pairEnergyDefines["CUTOFF"] = cl.doubleToString(nb.getCutoffDistance()); pairEnergyDefines["CUTOFF"] = cl.doubleToString(cutoff);
cl::Program program = cl.createProgram(pairEnergySrc, pairEnergyDefines); cl::Program program = cl.createProgram(pairEnergySrc, pairEnergyDefines);
pairEnergyKernel = cl::Kernel(program, "computeN2Energy"); pairEnergyKernel = cl::Kernel(program, "computeN2Energy");
pairEnergySrc = ""; pairEnergySrc = "";
......
...@@ -220,9 +220,9 @@ __kernel void computeNonbonded( ...@@ -220,9 +220,9 @@ __kernel void computeNonbonded(
if (numTiles <= maxTiles) { if (numTiles <= maxTiles) {
x = tiles[pos]; x = tiles[pos];
real4 blockSizeX = blockSize[x]; real4 blockSizeX = blockSize[x];
singlePeriodicCopy = (0.5f*periodicBoxSize.x-blockSizeX.x >= CUTOFF && singlePeriodicCopy = (0.5f*periodicBoxSize.x-blockSizeX.x >= MAX_CUTOFF &&
0.5f*periodicBoxSize.y-blockSizeX.y >= CUTOFF && 0.5f*periodicBoxSize.y-blockSizeX.y >= MAX_CUTOFF &&
0.5f*periodicBoxSize.z-blockSizeX.z >= CUTOFF); 0.5f*periodicBoxSize.z-blockSizeX.z >= MAX_CUTOFF);
} }
else else
#endif #endif
......
...@@ -973,6 +973,62 @@ void testInteractionGroupLongRangeCorrection() { ...@@ -973,6 +973,62 @@ void testInteractionGroupLongRangeCorrection() {
ASSERT_EQUAL_TOL(expected, energy2-energy1, 1e-4); ASSERT_EQUAL_TOL(expected, energy2-energy1, 1e-4);
} }
void testMultipleCutoffs() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
// Add multiple nonbonded forces that have different cutoffs.
CustomNonbondedForce* nonbonded1 = new CustomNonbondedForce("2*r");
nonbonded1->addParticle(vector<double>());
nonbonded1->addParticle(vector<double>());
nonbonded1->setNonbondedMethod(CustomNonbondedForce::CutoffNonPeriodic);
nonbonded1->setCutoffDistance(2.5);
system.addForce(nonbonded1);
CustomNonbondedForce* nonbonded2 = new CustomNonbondedForce("3*r");
nonbonded2->addParticle(vector<double>());
nonbonded2->addParticle(vector<double>());
nonbonded2->setNonbondedMethod(CustomNonbondedForce::CutoffNonPeriodic);
nonbonded2->setCutoffDistance(2.9);
nonbonded2->setForceGroup(1);
system.addForce(nonbonded2);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
positions[1] = Vec3(0, 0, 0);
for (double r = 2.4; r < 3.2; r += 0.2) {
positions[1][1] = r;
context.setPositions(positions);
double e1 = (r < 2.5 ? 2.0*r : 0.0);
double e2 = (r < 2.9 ? 3.0*r : 0.0);
double f1 = (r < 2.5 ? 2.0 : 0.0);
double f2 = (r < 2.9 ? 3.0 : 0.0);
// Check the first force.
State state = context.getState(State::Forces | State::Energy, false, 1);
ASSERT_EQUAL_VEC(Vec3(0, f1, 0), state.getForces()[0], TOL);
ASSERT_EQUAL_VEC(Vec3(0, -f1, 0), state.getForces()[1], TOL);
ASSERT_EQUAL_TOL(e1, state.getPotentialEnergy(), TOL);
// Check the second force.
state = context.getState(State::Forces | State::Energy, false, 2);
ASSERT_EQUAL_VEC(Vec3(0, f2, 0), state.getForces()[0], TOL);
ASSERT_EQUAL_VEC(Vec3(0, -f2, 0), state.getForces()[1], TOL);
ASSERT_EQUAL_TOL(e2, state.getPotentialEnergy(), TOL);
// Check the sum of both forces.
state = context.getState(State::Forces | State::Energy);
ASSERT_EQUAL_VEC(Vec3(0, f1+f2, 0), state.getForces()[0], TOL);
ASSERT_EQUAL_VEC(Vec3(0, -f1-f2, 0), state.getForces()[1], TOL);
ASSERT_EQUAL_TOL(e1+e2, state.getPotentialEnergy(), TOL);
}
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
try { try {
if (argc > 1) if (argc > 1)
...@@ -997,6 +1053,7 @@ int main(int argc, char* argv[]) { ...@@ -997,6 +1053,7 @@ int main(int argc, char* argv[]) {
testInteractionGroups(); testInteractionGroups();
testLargeInteractionGroup(); testLargeInteractionGroup();
testInteractionGroupLongRangeCorrection(); testInteractionGroupLongRangeCorrection();
testMultipleCutoffs();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment