Commit f11a6aea authored by Peter Eastman
Browse files

Fixed bugs in OpenCL implementation of CustomGBForce

parent 09777f85
...@@ -1298,7 +1298,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -1298,7 +1298,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
tabulatedFunctions.push_back(new OpenCLArray<mm_float4>(cl, values.size()-1, "TabulatedFunction")); tabulatedFunctions.push_back(new OpenCLArray<mm_float4>(cl, values.size()-1, "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float4", sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer())); cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float4", sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
tableArgs << ", __global float4* arrayName"; tableArgs << ", __global float4* " << arrayName;
} }
if (force.getNumFunctions() > 0) { if (force.getNumFunctions() > 0) {
tabulatedFunctionParams = new OpenCLArray<mm_float4>(cl, tabulatedFunctionParamsVec.size(), "tabulatedFunctionParameters", false, CL_MEM_READ_ONLY); tabulatedFunctionParams = new OpenCLArray<mm_float4>(cl, tabulatedFunctionParamsVec.size(), "tabulatedFunctionParameters", false, CL_MEM_READ_ONLY);
...@@ -1709,6 +1709,8 @@ void OpenCLCalcCustomGBForceKernel::executeForces(ContextImpl& context) { ...@@ -1709,6 +1709,8 @@ void OpenCLCalcCustomGBForceKernel::executeForces(ContextImpl& context) {
int index = 0; int index = 0;
pairValueKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer()); pairValueKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
pairValueKernel.setArg(index++, OpenCLContext::ThreadBlockSize*sizeof(cl_float4), NULL); pairValueKernel.setArg(index++, OpenCLContext::ThreadBlockSize*sizeof(cl_float4), NULL);
pairValueKernel.setArg<cl::Buffer>(index++, cl.getNonbondedUtilities().getExclusions().getDeviceBuffer());
pairValueKernel.setArg<cl::Buffer>(index++, cl.getNonbondedUtilities().getExclusionIndices().getDeviceBuffer());
pairValueKernel.setArg<cl::Buffer>(index++, valueBuffers->getDeviceBuffer()); pairValueKernel.setArg<cl::Buffer>(index++, valueBuffers->getDeviceBuffer());
pairValueKernel.setArg(index++, OpenCLContext::ThreadBlockSize*sizeof(cl_float), NULL); pairValueKernel.setArg(index++, OpenCLContext::ThreadBlockSize*sizeof(cl_float), NULL);
pairValueKernel.setArg(index++, OpenCLContext::ThreadBlockSize*sizeof(cl_float), NULL); pairValueKernel.setArg(index++, OpenCLContext::ThreadBlockSize*sizeof(cl_float), NULL);
...@@ -1754,6 +1756,8 @@ void OpenCLCalcCustomGBForceKernel::executeForces(ContextImpl& context) { ...@@ -1754,6 +1756,8 @@ void OpenCLCalcCustomGBForceKernel::executeForces(ContextImpl& context) {
pairEnergyKernel.setArg(index++, OpenCLContext::ThreadBlockSize*sizeof(cl_float4), NULL); pairEnergyKernel.setArg(index++, OpenCLContext::ThreadBlockSize*sizeof(cl_float4), NULL);
pairEnergyKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer()); pairEnergyKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
pairEnergyKernel.setArg(index++, OpenCLContext::ThreadBlockSize*sizeof(cl_float4), NULL); pairEnergyKernel.setArg(index++, OpenCLContext::ThreadBlockSize*sizeof(cl_float4), NULL);
pairEnergyKernel.setArg<cl::Buffer>(index++, cl.getNonbondedUtilities().getExclusions().getDeviceBuffer());
pairEnergyKernel.setArg<cl::Buffer>(index++, cl.getNonbondedUtilities().getExclusionIndices().getDeviceBuffer());
pairEnergyKernel.setArg(index++, OpenCLContext::ThreadBlockSize*sizeof(cl_float4), NULL); pairEnergyKernel.setArg(index++, OpenCLContext::ThreadBlockSize*sizeof(cl_float4), NULL);
if (nb.getUseCutoff()) { if (nb.getUseCutoff()) {
pairEnergyKernel.setArg<cl::Buffer>(index++, nb.getInteractingTiles().getDeviceBuffer()); pairEnergyKernel.setArg<cl::Buffer>(index++, nb.getInteractingTiles().getDeviceBuffer());
......
...@@ -170,6 +170,18 @@ public: ...@@ -170,6 +170,18 @@ public:
OpenCLArray<cl_uint>& getInteractionFlags() { OpenCLArray<cl_uint>& getInteractionFlags() {
return *interactionFlags; return *interactionFlags;
} }
/**
 * Access the OpenCLArray holding the exclusion flags.
 *
 * @return a reference to the array that the exclusions member points to
 */
OpenCLArray<cl_uint>& getExclusions() {
    OpenCLArray<cl_uint>& exclusionFlags = *exclusions;
    return exclusionFlags;
}
/**
 * Access the OpenCLArray that gives, for each tile, the index into the
 * exclusion array.
 *
 * @return a reference to the array that the exclusionIndex member points to
 */
OpenCLArray<cl_uint>& getExclusionIndices() {
    OpenCLArray<cl_uint>& perTileIndices = *exclusionIndex;
    return perTileIndices;
}
/** /**
* Create a Kernel for evaluating a nonbonded interaction. Cutoffs and periodic boundary conditions * Create a Kernel for evaluating a nonbonded interaction. Cutoffs and periodic boundary conditions
* are assumed to be the same as those for the default interaction Kernel, since this kernel will use * are assumed to be the same as those for the default interaction Kernel, since this kernel will use
......
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
*/ */
__kernel void computeN2Energy(__global float4* forceBuffers, __global float* energyBuffer, __local float4* local_force, __kernel void computeN2Energy(__global float4* forceBuffers, __global float* energyBuffer, __local float4* local_force,
__global float4* posq, __local float4* local_posq, __local float4* tempBuffer, __global unsigned int* tiles, __global float4* posq, __local float4* local_posq, __global unsigned int* exclusions, __global unsigned int* exclusionIndices,
__local float4* tempBuffer, __global unsigned int* tiles,
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
__global unsigned int* interactionFlags, __global unsigned int* interactionCount __global unsigned int* interactionFlags, __global unsigned int* interactionCount
#else #else
...@@ -56,7 +57,11 @@ __kernel void computeN2Energy(__global float4* forceBuffers, __global float* ene ...@@ -56,7 +57,11 @@ __kernel void computeN2Energy(__global float4* forceBuffers, __global float* ene
delta.y -= floor(delta.y/PERIODIC_BOX_SIZE_Y+0.5f)*PERIODIC_BOX_SIZE_Y; delta.y -= floor(delta.y/PERIODIC_BOX_SIZE_Y+0.5f)*PERIODIC_BOX_SIZE_Y;
delta.z -= floor(delta.z/PERIODIC_BOX_SIZE_Z+0.5f)*PERIODIC_BOX_SIZE_Z; delta.z -= floor(delta.z/PERIODIC_BOX_SIZE_Z+0.5f)*PERIODIC_BOX_SIZE_Z;
#endif #endif
float r = sqrt(delta.x*delta.x + delta.y*delta.y + delta.z*delta.z); float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
if (r2 < CUTOFF_SQUARED) {
#endif
float r = sqrt(r2);
LOAD_ATOM2_PARAMETERS LOAD_ATOM2_PARAMETERS
atom2 = y+j; atom2 = y+j;
float dEdR = 0.0f; float dEdR = 0.0f;
...@@ -68,6 +73,9 @@ __kernel void computeN2Energy(__global float4* forceBuffers, __global float* ene ...@@ -68,6 +73,9 @@ __kernel void computeN2Energy(__global float4* forceBuffers, __global float* ene
energy += 0.5f*tempEnergy; energy += 0.5f*tempEnergy;
delta.xyz *= dEdR; delta.xyz *= dEdR;
force.xyz -= delta.xyz; force.xyz -= delta.xyz;
#ifdef USE_CUTOFF
}
#endif
#ifdef USE_EXCLUSIONS #ifdef USE_EXCLUSIONS
excl >>= 1; excl >>= 1;
#endif #endif
...@@ -121,7 +129,11 @@ __kernel void computeN2Energy(__global float4* forceBuffers, __global float* ene ...@@ -121,7 +129,11 @@ __kernel void computeN2Energy(__global float4* forceBuffers, __global float* ene
delta.y -= floor(delta.y/PERIODIC_BOX_SIZE_Y+0.5f)*PERIODIC_BOX_SIZE_Y; delta.y -= floor(delta.y/PERIODIC_BOX_SIZE_Y+0.5f)*PERIODIC_BOX_SIZE_Y;
delta.z -= floor(delta.z/PERIODIC_BOX_SIZE_Z+0.5f)*PERIODIC_BOX_SIZE_Z; delta.z -= floor(delta.z/PERIODIC_BOX_SIZE_Z+0.5f)*PERIODIC_BOX_SIZE_Z;
#endif #endif
float r = sqrt(delta.x*delta.x + delta.y*delta.y + delta.z*delta.z); float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
if (r2 < CUTOFF_SQUARED) {
#endif
float r = sqrt(r2);
LOAD_ATOM2_PARAMETERS LOAD_ATOM2_PARAMETERS
atom2 = y+tj; atom2 = y+tj;
float dEdR = 0.0f; float dEdR = 0.0f;
...@@ -136,6 +148,9 @@ __kernel void computeN2Energy(__global float4* forceBuffers, __global float* ene ...@@ -136,6 +148,9 @@ __kernel void computeN2Energy(__global float4* forceBuffers, __global float* ene
atom2 = tbx+tj; atom2 = tbx+tj;
local_force[atom2].xyz += delta.xyz; local_force[atom2].xyz += delta.xyz;
RECORD_DERIVATIVE_2 RECORD_DERIVATIVE_2
#ifdef USE_CUTOFF
}
#endif
#ifdef USE_EXCLUSIONS #ifdef USE_EXCLUSIONS
excl >>= 1; excl >>= 1;
#endif #endif
......
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
* Compute a value based on pair interactions. * Compute a value based on pair interactions.
*/ */
__kernel void computeN2Value(__global float4* posq, __local float4* local_posq, __global float* global_value, __kernel void computeN2Value(__global float4* posq, __local float4* local_posq, __global unsigned int* exclusions,
__local float* local_value, __local float* tempBuffer, __global unsigned int* tiles, __global unsigned int* exclusionIndices, __global float* global_value, __local float* local_value,
__local float* tempBuffer, __global unsigned int* tiles,
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
__global unsigned int* interactionFlags, __global unsigned int* interactionCount __global unsigned int* interactionFlags, __global unsigned int* interactionCount
#else #else
...@@ -57,6 +58,9 @@ __kernel void computeN2Value(__global float4* posq, __local float4* local_posq, ...@@ -57,6 +58,9 @@ __kernel void computeN2Value(__global float4* posq, __local float4* local_posq,
delta.z -= floor(delta.z/PERIODIC_BOX_SIZE_Z+0.5f)*PERIODIC_BOX_SIZE_Z; delta.z -= floor(delta.z/PERIODIC_BOX_SIZE_Z+0.5f)*PERIODIC_BOX_SIZE_Z;
#endif #endif
float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z; float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
if (r2 < CUTOFF_SQUARED) {
#endif
float r = sqrt(r2); float r = sqrt(r2);
LOAD_ATOM2_PARAMETERS LOAD_ATOM2_PARAMETERS
atom2 = y+j; atom2 = y+j;
...@@ -70,6 +74,9 @@ __kernel void computeN2Value(__global float4* posq, __local float4* local_posq, ...@@ -70,6 +74,9 @@ __kernel void computeN2Value(__global float4* posq, __local float4* local_posq,
COMPUTE_VALUE COMPUTE_VALUE
} }
value += tempValue1; value += tempValue1;
#ifdef USE_CUTOFF
}
#endif
#ifdef USE_EXCLUSIONS #ifdef USE_EXCLUSIONS
excl >>= 1; excl >>= 1;
#endif #endif
...@@ -112,15 +119,17 @@ __kernel void computeN2Value(__global float4* posq, __local float4* local_posq, ...@@ -112,15 +119,17 @@ __kernel void computeN2Value(__global float4* posq, __local float4* local_posq,
delta.z -= floor(delta.z/PERIODIC_BOX_SIZE_Z+0.5f)*PERIODIC_BOX_SIZE_Z; delta.z -= floor(delta.z/PERIODIC_BOX_SIZE_Z+0.5f)*PERIODIC_BOX_SIZE_Z;
#endif #endif
float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z; float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
float tempValue1 = 0.0f;
float tempValue2 = 0.0f;
if (r2 < CUTOFF_SQUARED) {
float r = sqrt(r2); float r = sqrt(r2);
LOAD_ATOM2_PARAMETERS LOAD_ATOM2_PARAMETERS
atom2 = y+j; atom2 = y+j;
float tempValue1 = 0.0f;
float tempValue2 = 0.0f;
if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) { if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
COMPUTE_VALUE COMPUTE_VALUE
} }
value += tempValue1; value += tempValue1;
}
tempBuffer[get_local_id(0)] = tempValue2; tempBuffer[get_local_id(0)] = tempValue2;
// Sum the forces on atom2. // Sum the forces on atom2.
...@@ -165,6 +174,9 @@ __kernel void computeN2Value(__global float4* posq, __local float4* local_posq, ...@@ -165,6 +174,9 @@ __kernel void computeN2Value(__global float4* posq, __local float4* local_posq,
delta.z -= floor(delta.z/PERIODIC_BOX_SIZE_Z+0.5f)*PERIODIC_BOX_SIZE_Z; delta.z -= floor(delta.z/PERIODIC_BOX_SIZE_Z+0.5f)*PERIODIC_BOX_SIZE_Z;
#endif #endif
float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z; float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
if (r2 < CUTOFF_SQUARED) {
#endif
float r = sqrt(r2); float r = sqrt(r2);
LOAD_ATOM2_PARAMETERS LOAD_ATOM2_PARAMETERS
atom2 = y+tj; atom2 = y+tj;
...@@ -179,6 +191,9 @@ __kernel void computeN2Value(__global float4* posq, __local float4* local_posq, ...@@ -179,6 +191,9 @@ __kernel void computeN2Value(__global float4* posq, __local float4* local_posq,
} }
value += tempValue1; value += tempValue1;
local_value[tbx+tj] += tempValue2; local_value[tbx+tj] += tempValue2;
#ifdef USE_CUTOFF
}
#endif
#ifdef USE_EXCLUSIONS #ifdef USE_EXCLUSIONS
excl >>= 1; excl >>= 1;
#endif #endif
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment