Commit 8af0ac1c authored by Peter Eastman's avatar Peter Eastman
Browse files

Fixed bugs in OpenCL CustomGBForce

parent bd9b78ba
...@@ -1494,7 +1494,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -1494,7 +1494,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
} }
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_INTERACTION"] = n2EnergySource.str(); replacements["COMPUTE_INTERACTION"] = n2EnergySource.str();
stringstream extraArgs, loadLocal1, loadLocal2, load1, load2, recordDeriv, storeDerivs1, storeDerivs2, declareTemps, setTemps; stringstream extraArgs, loadLocal1, loadLocal2, clearLocal, load1, load2, recordDeriv, storeDerivs1, storeDerivs2, declareTemps, setTemps;
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
extraArgs << ", __constant float* globals"; extraArgs << ", __constant float* globals";
for (int i = 0; i < (int) params->getBuffers().size(); i++) { for (int i = 0; i < (int) params->getBuffers().size(); i++) {
...@@ -1519,9 +1519,9 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -1519,9 +1519,9 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
const OpenCLNonbondedUtilities::ParameterInfo& buffer = energyDerivs->getBuffers()[i]; const OpenCLNonbondedUtilities::ParameterInfo& buffer = energyDerivs->getBuffers()[i];
string index = intToString(i+1); string index = intToString(i+1);
extraArgs << ", __global " << buffer.getType() << "* derivBuffers" << index << ", __local " << buffer.getType() << "* local_deriv" << index; extraArgs << ", __global " << buffer.getType() << "* derivBuffers" << index << ", __local " << buffer.getType() << "* local_deriv" << index;
loadLocal2 << "local_deriv" << index << "[get_local_id(0)] = 0.0f;\n"; clearLocal << "local_deriv" << index << "[get_local_id(0)] = 0.0f;\n";
load1 << buffer.getType() << " deriv" << index << "_1 = 0;\n"; load1 << buffer.getType() << " deriv" << index << "_1 = 0.0f;\n";
load2 << buffer.getType() << " deriv" << index << "_2 = 0;\n"; load2 << buffer.getType() << " deriv" << index << "_2 = 0.0f;\n";
recordDeriv << "local_deriv" << index << "[atom2] += deriv" << index << "_2;\n"; recordDeriv << "local_deriv" << index << "[atom2] += deriv" << index << "_2;\n";
storeDerivs1 << "STORE_DERIVATIVE_1(" << index << ")"; storeDerivs1 << "STORE_DERIVATIVE_1(" << index << ")";
storeDerivs2 << "STORE_DERIVATIVE_2(" << index << ")"; storeDerivs2 << "STORE_DERIVATIVE_2(" << index << ")";
...@@ -1531,6 +1531,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -1531,6 +1531,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str(); replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str(); replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str(); replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
replacements["CLEAR_LOCAL_DERIVATIVES"] = clearLocal.str();
replacements["LOAD_ATOM1_PARAMETERS"] = load1.str(); replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
replacements["LOAD_ATOM2_PARAMETERS"] = load2.str(); replacements["LOAD_ATOM2_PARAMETERS"] = load2.str();
replacements["RECORD_DERIVATIVE_2"] = recordDeriv.str(); replacements["RECORD_DERIVATIVE_2"] = recordDeriv.str();
......
...@@ -110,6 +110,7 @@ __kernel void computeN2Energy(__global float4* forceBuffers, __global float* ene ...@@ -110,6 +110,7 @@ __kernel void computeN2Energy(__global float4* forceBuffers, __global float* ene
LOAD_LOCAL_PARAMETERS_FROM_GLOBAL LOAD_LOCAL_PARAMETERS_FROM_GLOBAL
} }
local_force[get_local_id(0)] = 0.0f; local_force[get_local_id(0)] = 0.0f;
CLEAR_LOCAL_DERIVATIVES
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
// Compute the full set of interactions in this tile. // Compute the full set of interactions in this tile.
...@@ -151,7 +152,7 @@ __kernel void computeN2Energy(__global float4* forceBuffers, __global float* ene ...@@ -151,7 +152,7 @@ __kernel void computeN2Energy(__global float4* forceBuffers, __global float* ene
energy += tempEnergy; energy += tempEnergy;
delta.xyz *= dEdR; delta.xyz *= dEdR;
force.xyz -= delta.xyz; force.xyz -= delta.xyz;
atom2 = baseLocalAtom+tj; atom2 = baseLocalAtom+tj+forceBufferOffset;
local_force[baseLocalAtom+tj+forceBufferOffset].xyz += delta.xyz; local_force[baseLocalAtom+tj+forceBufferOffset].xyz += delta.xyz;
RECORD_DERIVATIVE_2 RECORD_DERIVATIVE_2
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
......
...@@ -101,6 +101,7 @@ __kernel void computeN2Energy(__global float4* forceBuffers, __global float* ene ...@@ -101,6 +101,7 @@ __kernel void computeN2Energy(__global float4* forceBuffers, __global float* ene
LOAD_LOCAL_PARAMETERS_FROM_GLOBAL LOAD_LOCAL_PARAMETERS_FROM_GLOBAL
} }
local_force[get_local_id(0)] = 0.0f; local_force[get_local_id(0)] = 0.0f;
CLEAR_LOCAL_DERIVATIVES
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
unsigned int flags = interactionFlags[pos]; unsigned int flags = interactionFlags[pos];
if (!hasExclusions && flags == 0) { if (!hasExclusions && flags == 0) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment