Unverified Commit 2443dcee authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Common implementation of NonbondedForce (#4922)

* Use common API for kernels

* More code uses common interface

* Bug fixes

* Unified interface for sorting

* Simplified interface for FFT

* Use common event API for synchronization

* Minor changes to make code more consistent between platforms

* Common implementation of NonbondedForce

* Bug fixes

* Flag to enable list of single pairs

* CUDA and OpenCL use common implementation of NonbondedForce

* Fixed compilation error

* HIP uses common implementation of NonbondedForce
parent dfb8d755
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2019 Stanford University and the Authors. * * Portions copyright (c) 2019-2025 Stanford University and the Authors. *
* Portions copyright (c) 2020 Advanced Micro Devices, Inc. * * Portions copyright (c) 2020 Advanced Micro Devices, Inc. *
* Authors: Peter Eastman, Nicholas Curtis * * Authors: Peter Eastman, Nicholas Curtis *
* Contributors: * * Contributors: *
...@@ -49,6 +49,11 @@ public: ...@@ -49,6 +49,11 @@ public:
* Block until all operations started before the call to enqueue() have completed. * Block until all operations started before the call to enqueue() have completed.
*/ */
void wait(); void wait();
/**
* Enqueue a barrier that causes a specified ComputeQueue to block until all
* operations started before the call to enqueue() have completed.
*/
void queueWait(ComputeQueue queue);
private: private:
HipContext& context; HipContext& context;
hipEvent_t event; hipEvent_t event;
......
...@@ -48,7 +48,7 @@ class HipContext; ...@@ -48,7 +48,7 @@ class HipContext;
* multiply every value of the original data set by the total number of data points. * multiply every value of the original data set by the total number of data points.
*/ */
class OPENMM_EXPORT_COMMON HipFFT3D : public FFT3D { class OPENMM_EXPORT_COMMON HipFFT3D : public FFT3DImpl {
public: public:
/** /**
* Create an HipFFT3D object for performing transforms of a particular size. * Create an HipFFT3D object for performing transforms of a particular size.
......
...@@ -31,11 +31,12 @@ ...@@ -31,11 +31,12 @@
#include "HipPlatform.h" #include "HipPlatform.h"
#include "HipArray.h" #include "HipArray.h"
#include "HipContext.h" #include "HipContext.h"
#include "HipFFT3D.h"
#include "HipSort.h"
#include "openmm/kernels.h" #include "openmm/kernels.h"
#include "openmm/System.h" #include "openmm/System.h"
#include "openmm/common/CommonKernels.h" #include "openmm/common/CommonKernels.h"
#include "openmm/common/CommonCalcNonbondedForce.h"
#include "openmm/common/ComputeSort.h"
#include "openmm/common/FFT3D.h"
namespace OpenMM { namespace OpenMM {
...@@ -86,12 +87,11 @@ private: ...@@ -86,12 +87,11 @@ private:
/** /**
* This kernel is invoked by NonbondedForce to calculate the forces acting on the system. * This kernel is invoked by NonbondedForce to calculate the forces acting on the system.
*/ */
class HipCalcNonbondedForceKernel : public CalcNonbondedForceKernel { class HipCalcNonbondedForceKernel : public CommonCalcNonbondedForceKernel {
public: public:
HipCalcNonbondedForceKernel(std::string name, const Platform& platform, HipContext& cu, const System& system) : CalcNonbondedForceKernel(name, platform), HipCalcNonbondedForceKernel(std::string name, const Platform& platform, HipContext& cu, const System& system) :
cu(cu), hasInitializedFFT(false), sort(NULL), dispersionFft(NULL), fft(NULL), pmeio(NULL), useFixedPointChargeSpreading(false), usePmeStream(false) { CommonCalcNonbondedForceKernel(name, platform, cu, system), cu(cu) {
} }
~HipCalcNonbondedForceKernel();
/** /**
* Initialize the kernel. * Initialize the kernel.
* *
...@@ -99,123 +99,8 @@ public: ...@@ -99,123 +99,8 @@ public:
* @param force the NonbondedForce this kernel will be used for * @param force the NonbondedForce this kernel will be used for
*/ */
void initialize(const System& system, const NonbondedForce& force); void initialize(const System& system, const NonbondedForce& force);
/**
* Execute the kernel to calculate the forces and/or energy.
*
* @param context the context in which to execute this kernel
* @param includeForces true if forces should be calculated
* @param includeEnergy true if the energy should be calculated
* @param includeDirect true if direct space interactions should be included
* @param includeReciprocal true if reciprocal space interactions should be included
* @return the potential energy due to the force
*/
double execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal);
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the NonbondedForce to copy the parameters from
* @param firstParticle the index of the first particle whose parameters might have changed
* @param lastParticle the index of the last particle whose parameters might have changed
* @param firstException the index of the first exception whose parameters might have changed
* @param lastException the index of the last exception whose parameters might have changed
*/
void copyParametersToContext(ContextImpl& context, const NonbondedForce& force, int firstParticle, int lastParticle, int firstException, int lastException);
/**
* Get the parameters being used for PME.
*
* @param alpha the separation parameter
* @param nx the number of grid points along the X axis
* @param ny the number of grid points along the Y axis
* @param nz the number of grid points along the Z axis
*/
void getPMEParameters(double& alpha, int& nx, int& ny, int& nz) const;
/**
* Get the dispersion parameters being used for the dispersion term in LJPME.
*
* @param alpha the separation parameter
* @param nx the number of grid points along the X axis
* @param ny the number of grid points along the Y axis
* @param nz the number of grid points along the Z axis
*/
void getLJPMEParameters(double& alpha, int& nx, int& ny, int& nz) const;
private: private:
class SortTrait : public HipSort::SortTrait {
int getDataSize() const {return 8;}
int getKeySize() const {return 4;}
const char* getDataType() const {return "int2";}
const char* getKeyType() const {return "int";}
const char* getMinKey() const {return "(-2147483647-1)";}
const char* getMaxKey() const {return "2147483647";}
const char* getMaxValue() const {return "make_int2(2147483647, 2147483647)";}
const char* getSortKey() const {return "value.y";}
};
class ForceInfo;
class PmeIO;
class PmePreComputation;
class PmePostComputation;
class SyncStreamPreComputation;
class SyncStreamPostComputation;
HipContext& cu; HipContext& cu;
ForceInfo* info;
bool hasInitializedFFT;
HipArray charges;
HipArray sigmaEpsilon;
HipArray exceptionParams;
HipArray exclusionAtoms;
HipArray exclusionParams;
HipArray baseParticleParams;
HipArray baseExceptionParams;
HipArray particleParamOffsets;
HipArray exceptionParamOffsets;
HipArray particleOffsetIndices;
HipArray exceptionOffsetIndices;
HipArray globalParams;
HipArray cosSinSums;
HipArray pmeGrid1;
HipArray pmeGrid2;
HipArray pmeBsplineModuliX;
HipArray pmeBsplineModuliY;
HipArray pmeBsplineModuliZ;
HipArray pmeDispersionBsplineModuliX;
HipArray pmeDispersionBsplineModuliY;
HipArray pmeDispersionBsplineModuliZ;
HipArray pmeAtomGridIndex;
HipArray pmeEnergyBuffer;
HipArray chargeBuffer;
HipSort* sort;
Kernel cpuPme;
PmeIO* pmeio;
ComputeQueue pmeQueue;
hipEvent_t pmeSyncEvent, paramsSyncEvent;
HipFFT3D* fft;
HipFFT3D* dispersionFft;
hipFunction_t computeParamsKernel, computeExclusionParamsKernel, computePlasmaCorrectionKernel;
hipFunction_t ewaldSumsKernel;
hipFunction_t ewaldForcesKernel;
hipFunction_t pmeGridIndexKernel;
hipFunction_t pmeDispersionGridIndexKernel;
hipFunction_t pmeSpreadChargeKernel;
hipFunction_t pmeDispersionSpreadChargeKernel;
hipFunction_t pmeFinishSpreadChargeKernel;
hipFunction_t pmeDispersionFinishSpreadChargeKernel;
hipFunction_t pmeEvalEnergyKernel;
hipFunction_t pmeEvalDispersionEnergyKernel;
hipFunction_t pmeConvolutionKernel;
hipFunction_t pmeDispersionConvolutionKernel;
hipFunction_t pmeInterpolateForceKernel;
hipFunction_t pmeInterpolateDispersionForceKernel;
std::vector<std::pair<int, int> > exceptionAtoms;
std::vector<std::string> paramNames;
std::vector<double> paramValues;
std::map<int, int> exceptionIndex;
double ewaldSelfEnergy, dispersionCoefficient, alpha, dispersionAlpha, totalCharge;
int interpolateForceThreads;
int gridSizeX, gridSizeY, gridSizeZ;
int dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ;
bool hasCoulomb, hasLJ, useFixedPointChargeSpreading, usePmeStream, doLJPME, usePosqCharges, recomputeParams, hasOffsets;
NonbondedMethod nonbondedMethod;
static const int PmeOrder = 5;
}; };
/** /**
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "openmm/System.h" #include "openmm/System.h"
#include "HipArray.h" #include "HipArray.h"
#include "HipExpressionUtilities.h" #include "HipExpressionUtilities.h"
#include "openmm/common/ComputeSort.h"
#include "openmm/common/NonbondedUtilities.h" #include "openmm/common/NonbondedUtilities.h"
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <sstream> #include <sstream>
...@@ -40,7 +41,6 @@ ...@@ -40,7 +41,6 @@
namespace OpenMM { namespace OpenMM {
class HipContext; class HipContext;
class HipSort;
/** /**
* This class provides a generic interface for calculating nonbonded interactions. It does this in two * This class provides a generic interface for calculating nonbonded interactions. It does this in two
...@@ -72,20 +72,6 @@ public: ...@@ -72,20 +72,6 @@ public:
class ParameterInfo; class ParameterInfo;
HipNonbondedUtilities(HipContext& context); HipNonbondedUtilities(HipContext& context);
~HipNonbondedUtilities(); ~HipNonbondedUtilities();
/**
* Add a nonbonded interaction to be evaluated by the default interaction kernel.
*
* @param usesCutoff specifies whether a cutoff should be applied to this interaction
* @param usesPeriodic specifies whether periodic boundary conditions should be applied to this interaction
* @param usesExclusions specifies whether this interaction uses exclusions. If this is true, it must have identical exclusions to every other interaction.
* @param cutoffDistance the cutoff distance for this interaction (ignored if usesCutoff is false)
* @param exclusionList for each atom, specifies the list of other atoms whose interactions should be excluded
* @param kernel the code to evaluate the interaction
* @param forceGroup the force group in which the interaction should be calculated
* @param usesNeighborList specifies whether a neighbor list should be used to optimize this interaction. This should
* be viewed as only a suggestion. Even when it is false, a neighbor list may be used anyway.
*/
void addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const std::vector<std::vector<int> >& exclusionList, const std::string& kernel, int forceGroup, bool usesNeighborList = true);
/** /**
* Add a nonbonded interaction to be evaluated by the default interaction kernel. * Add a nonbonded interaction to be evaluated by the default interaction kernel.
* *
...@@ -100,7 +86,9 @@ public: ...@@ -100,7 +86,9 @@ public:
* be viewed as only a suggestion. Even when it is false, a neighbor list may be used anyway. * be viewed as only a suggestion. Even when it is false, a neighbor list may be used anyway.
* @param supportsPairList specifies whether this interaction can work with a neighbor list that uses a separate pair list * @param supportsPairList specifies whether this interaction can work with a neighbor list that uses a separate pair list
*/ */
void addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const std::vector<std::vector<int> >& exclusionList, const std::string& kernel, int forceGroup, bool usesNeighborList, bool supportsPairList); void addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance,
const std::vector<std::vector<int> >& exclusionList, const std::string& kernel,
int forceGroup, bool useNeighborList=true, bool supportsPairList=false);
/** /**
* Add a per-atom parameter that the default interaction kernel may depend on. * Add a per-atom parameter that the default interaction kernel may depend on.
*/ */
...@@ -344,7 +332,7 @@ private: ...@@ -344,7 +332,7 @@ private:
HipArray largeBlockBoundingBox; HipArray largeBlockBoundingBox;
HipArray oldPositions; HipArray oldPositions;
HipArray rebuildNeighborList; HipArray rebuildNeighborList;
HipSort* blockSorter; ComputeSort blockSorter;
hipEvent_t downloadCountEvent; hipEvent_t downloadCountEvent;
unsigned int* pinnedCountBuffer; unsigned int* pinnedCountBuffer;
std::vector<void*> forceArgs, findBlockBoundsArgs, computeSortKeysArgs, sortBoxDataArgs, findInteractingBlocksArgs, copyInteractionCountsArgs; std::vector<void*> forceArgs, findBlockBoundsArgs, computeSortKeysArgs, sortBoxDataArgs, findInteractingBlocksArgs, copyInteractionCountsArgs;
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2010-2018 Stanford University and the Authors. * * Portions copyright (c) 2010-2025 Stanford University and the Authors. *
* Portions copyright (c) 2020-2023 Advanced Micro Devices, Inc. * * Portions copyright (c) 2020-2023 Advanced Micro Devices, Inc. *
* Authors: Peter Eastman, Nicholas Curtis * * Authors: Peter Eastman, Nicholas Curtis *
* Contributors: * * Contributors: *
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "HipArray.h" #include "HipArray.h"
#include "openmm/common/ComputeSort.h"
#include "openmm/common/windowsExportCommon.h" #include "openmm/common/windowsExportCommon.h"
#include "HipContext.h" #include "HipContext.h"
...@@ -42,7 +43,7 @@ namespace OpenMM { ...@@ -42,7 +43,7 @@ namespace OpenMM {
* sort and the key for sorting it. Here is an example of a trait class for * sort and the key for sorting it. Here is an example of a trait class for
* sorting floats: * sorting floats:
* *
* class FloatTrait : public HipSort::SortTrait { * class FloatTrait : public ComputeSortImpl::SortTrait {
* int getDataSize() const {return 4;} * int getDataSize() const {return 4;}
* int getKeySize() const {return 4;} * int getKeySize() const {return 4;}
* const char* getDataType() const {return "float";} * const char* getDataType() const {return "float";}
...@@ -67,7 +68,7 @@ namespace OpenMM { ...@@ -67,7 +68,7 @@ namespace OpenMM {
* elements). * elements).
*/ */
class OPENMM_EXPORT_COMMON HipSort { class OPENMM_EXPORT_COMMON HipSort : public ComputeSortImpl {
public: public:
class SortTrait; class SortTrait;
/** /**
...@@ -81,15 +82,15 @@ public: ...@@ -81,15 +82,15 @@ public:
* @param uniform whether the input data is expected to follow a uniform or nonuniform * @param uniform whether the input data is expected to follow a uniform or nonuniform
* distribution. This argument is used only as a hint. * distribution. This argument is used only as a hint.
*/ */
HipSort(HipContext& context, SortTrait* trait, unsigned int length, bool uniform=true); HipSort(HipContext& context, ComputeSortImpl::SortTrait* trait, unsigned int length, bool uniform=true);
~HipSort(); ~HipSort();
/** /**
* Sort an array. * Sort an array.
*/ */
void sort(HipArray& data); void sort(ArrayInterface& data);
private: private:
HipContext& context; HipContext& context;
SortTrait* trait; ComputeSortImpl::SortTrait* trait;
HipArray counters; HipArray counters;
HipArray dataRange; HipArray dataRange;
HipArray bucketOfElement; HipArray bucketOfElement;
...@@ -101,48 +102,6 @@ private: ...@@ -101,48 +102,6 @@ private:
bool isShortList, uniform; bool isShortList, uniform;
}; };
/**
* A subclass of SortTrait defines the type of value to sort, and the key for sorting them.
*/
class HipSort::SortTrait {
public:
virtual ~SortTrait() {
}
/**
* Get the size of each data value in bytes.
*/
virtual int getDataSize() const = 0;
/**
* Get the size of each key value in bytes.
*/
virtual int getKeySize() const = 0;
/**
* Get the data type of the values to sort.
*/
virtual const char* getDataType() const = 0;
/**
* Get the data type of the sorting key.
*/
virtual const char* getKeyType() const = 0;
/**
* Get the minimum value a key can take.
*/
virtual const char* getMinKey() const = 0;
/**
* Get the maximum value a key can take.
*/
virtual const char* getMaxKey() const = 0;
/**
* Get a value whose key is guaranteed to equal getMaxKey().
*/
virtual const char* getMaxValue() const = 0;
/**
* Get the HIP code to select the key from the data value.
*/
virtual const char* getSortKey() const = 0;
};
} // namespace OpenMM } // namespace OpenMM
#endif // __OPENMM_HIPSORT_H__ #endif // __OPENMM_HIPSORT_H__
...@@ -39,6 +39,7 @@ ...@@ -39,6 +39,7 @@
#include "HipNonbondedUtilities.h" #include "HipNonbondedUtilities.h"
#include "HipProgram.h" #include "HipProgram.h"
#include "HipQueue.h" #include "HipQueue.h"
#include "HipSort.h"
#include "openmm/common/ComputeArray.h" #include "openmm/common/ComputeArray.h"
#include "openmm/common/ContextSelector.h" #include "openmm/common/ContextSelector.h"
#include "SHA1.h" #include "SHA1.h"
...@@ -691,8 +692,12 @@ ComputeEvent HipContext::createEvent() { ...@@ -691,8 +692,12 @@ ComputeEvent HipContext::createEvent() {
return shared_ptr<ComputeEventImpl>(new HipEvent(*this)); return shared_ptr<ComputeEventImpl>(new HipEvent(*this));
} }
HipFFT3D* HipContext::createFFT(int xsize, int ysize, int zsize, bool realToComplex) { ComputeSort HipContext::createSort(ComputeSortImpl::SortTrait* trait, unsigned int length, bool uniform) {
return new HipFFT3D(*this, xsize, ysize, zsize, realToComplex); return shared_ptr<ComputeSortImpl>(new HipSort(*this, trait, length, uniform));
}
FFT3D HipContext::createFFT(int xsize, int ysize, int zsize, bool realToComplex) {
return FFT3D(new HipFFT3D(*this, xsize, ysize, zsize, realToComplex));
} }
int HipContext::findLegalFFTDimension(int minimum) { int HipContext::findLegalFFTDimension(int minimum) {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2019 Stanford University and the Authors. * * Portions copyright (c) 2019-2025 Stanford University and the Authors. *
* Portions copyright (c) 2020-2023 Advanced Micro Devices, Inc. * * Portions copyright (c) 2020-2023 Advanced Micro Devices, Inc. *
* Authors: Peter Eastman, Nicholas Curtis * * Authors: Peter Eastman, Nicholas Curtis *
* Contributors: * * Contributors: *
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "HipEvent.h" #include "HipEvent.h"
#include "HipQueue.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
using namespace OpenMM; using namespace OpenMM;
...@@ -49,3 +50,7 @@ void HipEvent::enqueue() { ...@@ -49,3 +50,7 @@ void HipEvent::enqueue() {
void HipEvent::wait() { void HipEvent::wait() {
hipEventSynchronize(event); hipEventSynchronize(event);
} }
void HipEvent::queueWait(ComputeQueue queue) {
hipStreamWaitEvent(dynamic_cast<HipQueue*>(queue.get())->getStream(), event, 0);
}
This diff is collapsed.
...@@ -31,7 +31,6 @@ ...@@ -31,7 +31,6 @@
#include "HipContext.h" #include "HipContext.h"
#include "HipKernelSources.h" #include "HipKernelSources.h"
#include "HipExpressionUtilities.h" #include "HipExpressionUtilities.h"
#include "HipSort.h"
#include <algorithm> #include <algorithm>
#include <map> #include <map>
#include <set> #include <set>
...@@ -48,7 +47,7 @@ using namespace std; ...@@ -48,7 +47,7 @@ using namespace std;
} }
class HipNonbondedUtilities::BlockSortTrait : public HipSort::SortTrait { class HipNonbondedUtilities::BlockSortTrait : public ComputeSortImpl::SortTrait {
public: public:
BlockSortTrait() {} BlockSortTrait() {}
int getDataSize() const {return sizeof(int);} int getDataSize() const {return sizeof(int);}
...@@ -62,7 +61,7 @@ public: ...@@ -62,7 +61,7 @@ public:
}; };
HipNonbondedUtilities::HipNonbondedUtilities(HipContext& context) : context(context), useCutoff(false), usePeriodic(false), useNeighborList(false), anyExclusions(false), usePadding(true), HipNonbondedUtilities::HipNonbondedUtilities(HipContext& context) : context(context), useCutoff(false), usePeriodic(false), useNeighborList(false), anyExclusions(false), usePadding(true),
blockSorter(NULL), pinnedCountBuffer(NULL), forceRebuildNeighborList(true), groupFlags(0), canUsePairList(true), tilesAfterReorder(0) { pinnedCountBuffer(NULL), forceRebuildNeighborList(true), groupFlags(0), canUsePairList(true), tilesAfterReorder(0) {
// Decide how many thread blocks to use. // Decide how many thread blocks to use.
string errorMessage = "Error initializing nonbonded utilities"; string errorMessage = "Error initializing nonbonded utilities";
...@@ -82,18 +81,13 @@ HipNonbondedUtilities::HipNonbondedUtilities(HipContext& context) : context(cont ...@@ -82,18 +81,13 @@ HipNonbondedUtilities::HipNonbondedUtilities(HipContext& context) : context(cont
} }
HipNonbondedUtilities::~HipNonbondedUtilities() { HipNonbondedUtilities::~HipNonbondedUtilities() {
if (blockSorter != NULL)
delete blockSorter;
if (pinnedCountBuffer != NULL) if (pinnedCountBuffer != NULL)
hipHostFree(pinnedCountBuffer); hipHostFree(pinnedCountBuffer);
hipEventDestroy(downloadCountEvent); hipEventDestroy(downloadCountEvent);
} }
void HipNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup, bool usesNeighborList) { void HipNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance,
addInteraction(usesCutoff, usesPeriodic, usesExclusions, cutoffDistance, exclusionList, kernel, forceGroup, usesNeighborList, false); const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup, bool usesNeighborList, bool supportsPairList) {
}
void HipNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup, bool usesNeighborList, bool supportsPairList) {
if (groupCutoff.size() > 0) { if (groupCutoff.size() > 0) {
if (usesCutoff != useCutoff) if (usesCutoff != useCutoff)
throw OpenMMException("All Forces must agree on whether to use a cutoff"); throw OpenMMException("All Forces must agree on whether to use a cutoff");
...@@ -304,7 +298,7 @@ void HipNonbondedUtilities::initialize(const System& system) { ...@@ -304,7 +298,7 @@ void HipNonbondedUtilities::initialize(const System& system) {
largeBlockBoundingBox.initialize(context, numAtomBlocks*4, elementSize, "largeBlockBoundingBox"); largeBlockBoundingBox.initialize(context, numAtomBlocks*4, elementSize, "largeBlockBoundingBox");
oldPositions.initialize(context, numAtoms, 4*elementSize, "oldPositions"); oldPositions.initialize(context, numAtoms, 4*elementSize, "oldPositions");
rebuildNeighborList.initialize<int>(context, 1, "rebuildNeighborList"); rebuildNeighborList.initialize<int>(context, 1, "rebuildNeighborList");
blockSorter = new HipSort(context, new BlockSortTrait(), numAtomBlocks, false); blockSorter = context.createSort(new BlockSortTrait(), numAtomBlocks, false);
vector<unsigned int> count(2, 0); vector<unsigned int> count(2, 0);
interactionCount.upload(count); interactionCount.upload(count);
rebuildNeighborList.upload(&count[0]); rebuildNeighborList.upload(&count[0]);
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment