Commit b90354c9 authored by Peter Eastman's avatar Peter Eastman
Browse files

CUDA and OpenCL implementation of parameter derivatives for CustomCentroidBondForce

parent ccb2000d
......@@ -979,6 +979,7 @@ public:
private:
int numGroups, numBonds;
bool needEnergyParamDerivs;
CudaContext& cu;
CudaParameterSet* params;
CudaArray* globals;
......
......@@ -4693,6 +4693,9 @@ void CudaCalcCustomCentroidBondForceKernel::initialize(const System& system, con
groupForcesArgs.push_back(cu.getPeriodicBoxVecXPointer());
groupForcesArgs.push_back(cu.getPeriodicBoxVecYPointer());
groupForcesArgs.push_back(cu.getPeriodicBoxVecZPointer());
needEnergyParamDerivs = (force.getNumEnergyParameterDerivatives() > 0);
if (needEnergyParamDerivs)
groupForcesArgs.push_back(NULL); // Derivatives buffer hasn't been created yet
// Record the tabulated functions.
......@@ -4736,6 +4739,8 @@ void CudaCalcCustomCentroidBondForceKernel::initialize(const System& system, con
const string& name = force.getPerBondParameterName(i);
variables[name] = "bondParams"+params->getParameterSuffix(i);
}
if (needEnergyParamDerivs)
extraArgs << ", mixed* __restrict__ energyParamDerivs";
if (force.getNumGlobalParameters() > 0) {
globals = CudaArray::create<float>(cu, force.getNumGlobalParameters(), "customCentroidBondGlobals");
globals->upload(globalParamValues);
......@@ -4763,7 +4768,7 @@ void CudaCalcCustomCentroidBondForceKernel::initialize(const System& system, con
atomNames.push_back("P"+index);
posNames.push_back("pos"+index);
}
stringstream compute;
stringstream compute, initParamDerivs, saveParamDerivs;
for (int i = 0; i < groupsPerBond; i++) {
compute<<"int group"<<(i+1)<<" = bondGroups[index+"<<(i*numBonds)<<"];\n";
compute<<"real4 pos"<<(i+1)<<" = centerPositions[group"<<(i+1)<<"];\n";
......@@ -4836,6 +4841,21 @@ void CudaCalcCustomCentroidBondForceKernel::initialize(const System& system, con
groupForcesArgs.push_back(&buffer.getMemory());
}
forceExpressions["energy += "] = energyExpression;
if (needEnergyParamDerivs) {
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
string paramName = force.getEnergyParameterDerivativeName(i);
cu.addEnergyParameterDerivative(paramName);
Lepton::ParsedExpression derivExpression = energyExpression.differentiate(paramName).optimize();
forceExpressions[string("energyParamDeriv")+cu.intToString(i)+" += "] = derivExpression;
initParamDerivs << "mixed energyParamDeriv" << i << " = 0;\n";
}
const vector<string>& allParamDerivNames = cu.getEnergyParamDerivNames();
int numDerivs = allParamDerivNames.size();
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
for (int index = 0; index < numDerivs; index++)
if (allParamDerivNames[index] == force.getEnergyParameterDerivativeName(i))
saveParamDerivs << "energyParamDerivs[(blockIdx.x*blockDim.x+threadIdx.x)*" << numDerivs << "+" << index << "] += energyParamDeriv" << i << ";\n";
}
compute << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
// Finally, apply forces to groups.
......@@ -4925,6 +4945,8 @@ void CudaCalcCustomCentroidBondForceKernel::initialize(const System& system, con
replacements["PADDED_NUM_ATOMS"] = cu.intToString(cu.getPaddedNumAtoms());
replacements["EXTRA_ARGS"] = extraArgs.str();
replacements["COMPUTE_FORCE"] = compute.str();
replacements["INIT_PARAM_DERIVS"] = initParamDerivs.str();
replacements["SAVE_PARAM_DERIVS"] = saveParamDerivs.str();
CUmodule module = cu.createModule(CudaKernelSources::vectorOps+cu.replaceStrings(CudaKernelSources::customCentroidBond, replacements));
computeCentersKernel = cu.getKernel(module, "computeGroupCenters");
groupForcesKernel = cu.getKernel(module, "computeGroupForces");
......@@ -4949,6 +4971,8 @@ double CudaCalcCustomCentroidBondForceKernel::execute(ContextImpl& context, bool
&groupOffsets->getDevicePointer(), &centerPositions->getDevicePointer()};
cu.executeKernel(computeCentersKernel, computeCentersArgs, CudaContext::TileSize*numGroups);
groupForcesArgs[1] = &cu.getEnergyBuffer().getDevicePointer();
if (needEnergyParamDerivs)
groupForcesArgs[9] = &cu.getEnergyParamDerivBuffer().getDevicePointer();
cu.executeKernel(groupForcesKernel, &groupForcesArgs[0], numBonds);
void* applyForcesArgs[] = {&groupParticles->getDevicePointer(), &groupWeights->getDevicePointer(), &groupOffsets->getDevicePointer(),
&groupForces->getDevicePointer(), &cu.getForce().getDevicePointer()};
......
......@@ -111,10 +111,12 @@ extern "C" __global__ void computeGroupForces(unsigned long long* __restrict__ g
const int* __restrict__ bondGroups, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ
EXTRA_ARGS) {
mixed energy = 0;
INIT_PARAM_DERIVS
for (int index = blockIdx.x*blockDim.x+threadIdx.x; index < NUM_BONDS; index += blockDim.x*gridDim.x) {
COMPUTE_FORCE
}
energyBuffer[blockIdx.x*blockDim.x+threadIdx.x] += energy;
SAVE_PARAM_DERIVS
}
/**
......
......@@ -962,6 +962,7 @@ public:
private:
int numGroups, numBonds;
bool needEnergyParamDerivs;
OpenCLContext& cl;
OpenCLParameterSet* params;
OpenCLArray* globals;
......
......@@ -4933,6 +4933,9 @@ void OpenCLCalcCustomCentroidBondForceKernel::initialize(const System& system, c
const string& name = force.getPerBondParameterName(i);
variables[name] = "bondParams"+params->getParameterSuffix(i);
}
needEnergyParamDerivs = (force.getNumEnergyParameterDerivatives() > 0);
if (needEnergyParamDerivs)
extraArgs << ", __global mixed* restrict energyParamDerivs";
if (force.getNumGlobalParameters() > 0) {
globals = OpenCLArray::create<float>(cl, force.getNumGlobalParameters(), "customCentroidBondGlobals");
globals->upload(globalParamValues);
......@@ -4959,7 +4962,7 @@ void OpenCLCalcCustomCentroidBondForceKernel::initialize(const System& system, c
atomNames.push_back("P"+index);
posNames.push_back("pos"+index);
}
stringstream compute;
stringstream compute, initParamDerivs, saveParamDerivs;
for (int i = 0; i < groupsPerBond; i++) {
compute<<"int group"<<(i+1)<<" = bondGroups[index+"<<(i*numBonds)<<"];\n";
compute<<"real4 pos"<<(i+1)<<" = centerPositions[group"<<(i+1)<<"];\n";
......@@ -5031,6 +5034,21 @@ void OpenCLCalcCustomCentroidBondForceKernel::initialize(const System& system, c
compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = globalParams"<<i<<"[index];\n";
}
forceExpressions["energy += "] = energyExpression;
if (needEnergyParamDerivs) {
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
string paramName = force.getEnergyParameterDerivativeName(i);
cl.addEnergyParameterDerivative(paramName);
Lepton::ParsedExpression derivExpression = energyExpression.differentiate(paramName).optimize();
forceExpressions[string("energyParamDeriv")+cl.intToString(i)+" += "] = derivExpression;
initParamDerivs << "mixed energyParamDeriv" << i << " = 0;\n";
}
const vector<string>& allParamDerivNames = cl.getEnergyParamDerivNames();
int numDerivs = allParamDerivNames.size();
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
for (int index = 0; index < numDerivs; index++)
if (allParamDerivNames[index] == force.getEnergyParameterDerivativeName(i))
saveParamDerivs << "energyParamDerivs[get_global_id(0)*" << numDerivs << "+" << index << "] += energyParamDeriv" << i << ";\n";
}
compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
// Finally, apply forces to groups.
......@@ -5119,6 +5137,8 @@ void OpenCLCalcCustomCentroidBondForceKernel::initialize(const System& system, c
replacements["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
replacements["EXTRA_ARGS"] = extraArgs.str();
replacements["COMPUTE_FORCE"] = compute.str();
replacements["INIT_PARAM_DERIVS"] = initParamDerivs.str();
replacements["SAVE_PARAM_DERIVS"] = saveParamDerivs.str();
cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customCentroidBond, replacements));
index = 0;
computeCentersKernel = cl::Kernel(program, "computeGroupCenters");
......@@ -5134,6 +5154,8 @@ void OpenCLCalcCustomCentroidBondForceKernel::initialize(const System& system, c
groupForcesKernel.setArg<cl::Buffer>(index++, centerPositions->getDeviceBuffer());
groupForcesKernel.setArg<cl::Buffer>(index++, bondGroups->getDeviceBuffer());
index += 5; // Periodic box information
if (needEnergyParamDerivs)
index++; // Deriv buffer hasn't been created yet.
for (int i = 0; i < tabulatedFunctions.size(); i++)
groupForcesKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
if (globals != NULL)
......@@ -5165,6 +5187,8 @@ double OpenCLCalcCustomCentroidBondForceKernel::execute(ContextImpl& context, bo
cl.executeKernel(computeCentersKernel, OpenCLContext::TileSize*numGroups);
groupForcesKernel.setArg<cl::Buffer>(1, cl.getEnergyBuffer().getDeviceBuffer());
setPeriodicBoxArgs(cl, groupForcesKernel, 4);
if (needEnergyParamDerivs)
groupForcesKernel.setArg<cl::Memory>(9, cl.getEnergyParamDerivBuffer().getDeviceBuffer());
cl.executeKernel(groupForcesKernel, numBonds);
applyForcesKernel.setArg<cl::Buffer>(4, cl.getLongForceBuffer().getDeviceBuffer());
cl.executeKernel(applyForcesKernel, OpenCLContext::TileSize*numGroups);
......
......@@ -116,10 +116,12 @@ __kernel void computeGroupForces(__global long* restrict groupForce, __global mi
__global const int* restrict bondGroups, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ
EXTRA_ARGS) {
mixed energy = 0;
INIT_PARAM_DERIVS
for (int index = get_global_id(0); index < NUM_BONDS; index += get_global_size(0)) {
COMPUTE_FORCE
}
energyBuffer[get_global_id(0)] += energy;
SAVE_PARAM_DERIVS
}
/**
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment