Commit 0c4af105 authored by Peter Eastman
Browse files

Optimizations to CustomGBForce

parent e7a00c6a
...@@ -1799,8 +1799,8 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -1799,8 +1799,8 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
// Record parameters and exclusions. // Record parameters and exclusions.
int numParticles = force.getNumParticles(); int numParticles = force.getNumParticles();
params = new OpenCLParameterSet(cl, force.getNumPerParticleParameters(), numParticles, "customGBParameters"); params = new OpenCLParameterSet(cl, force.getNumPerParticleParameters(), numParticles, "customGBParameters", true);
computedValues = new OpenCLParameterSet(cl, force.getNumComputedValues(), numParticles, "customGBComputedValues"); computedValues = new OpenCLParameterSet(cl, force.getNumComputedValues(), numParticles, "customGBComputedValues", true);
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
globals = new OpenCLArray<cl_float>(cl, force.getNumGlobalParameters(), "customGBGlobals", false, CL_MEM_READ_ONLY); globals = new OpenCLArray<cl_float>(cl, force.getNumGlobalParameters(), "customGBGlobals", false, CL_MEM_READ_ONLY);
vector<vector<cl_float> > paramVector(numParticles); vector<vector<cl_float> > paramVector(numParticles);
...@@ -1877,17 +1877,25 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -1877,17 +1877,25 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
valueDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]).optimize()); valueDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]).optimize());
} }
vector<vector<Lepton::ParsedExpression> > energyDerivExpressions(force.getNumEnergyTerms()); vector<vector<Lepton::ParsedExpression> > energyDerivExpressions(force.getNumEnergyTerms());
vector<bool> needChainForValue(force.getNumComputedValues(), false);
for (int i = 0; i < force.getNumEnergyTerms(); i++) { for (int i = 0; i < force.getNumEnergyTerms(); i++) {
string expression; string expression;
CustomGBForce::ComputationType type; CustomGBForce::ComputationType type;
force.getEnergyTermParameters(i, expression, type); force.getEnergyTermParameters(i, expression, type);
Lepton::ParsedExpression ex = Lepton::Parser::parse(expression, functions).optimize(); Lepton::ParsedExpression ex = Lepton::Parser::parse(expression, functions).optimize();
for (int j = 0; j < force.getNumComputedValues(); j++) { for (int j = 0; j < force.getNumComputedValues(); j++) {
if (type == CustomGBForce::SingleParticle) if (type == CustomGBForce::SingleParticle) {
energyDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]).optimize()); energyDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]).optimize());
if (!isZeroExpression(energyDerivExpressions[i].back()))
needChainForValue[j] = true;
}
else { else {
energyDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]+"1").optimize()); energyDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]+"1").optimize());
if (!isZeroExpression(energyDerivExpressions[i].back()))
needChainForValue[j] = true;
energyDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]+"2").optimize()); energyDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]+"2").optimize());
if (!isZeroExpression(energyDerivExpressions[i].back()))
needChainForValue[j] = true;
} }
} }
} }
...@@ -1895,11 +1903,11 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -1895,11 +1903,11 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
bool useLong = (cl.getSupports64BitGlobalAtomics() && !deviceIsCpu); bool useLong = (cl.getSupports64BitGlobalAtomics() && !deviceIsCpu);
if (useLong) { if (useLong) {
longEnergyDerivs = new OpenCLArray<cl_long>(cl, force.getNumComputedValues()*cl.getPaddedNumAtoms(), "customGBLongEnergyDerivatives"); longEnergyDerivs = new OpenCLArray<cl_long>(cl, force.getNumComputedValues()*cl.getPaddedNumAtoms(), "customGBLongEnergyDerivatives");
energyDerivs = new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms(), "customGBEnergyDerivatives"); energyDerivs = new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms(), "customGBEnergyDerivatives", true);
} }
else else
energyDerivs = new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms()*cl.getNonbondedUtilities().getNumForceBuffers(), "customGBEnergyDerivatives"); energyDerivs = new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms()*cl.getNonbondedUtilities().getNumForceBuffers(), "customGBEnergyDerivatives", true);
// Create the kernels. // Create the kernels.
bool useCutoff = (force.getNonbondedMethod() != CustomGBForce::NoCutoff); bool useCutoff = (force.getNonbondedMethod() != CustomGBForce::NoCutoff);
...@@ -1932,18 +1940,23 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -1932,18 +1940,23 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename); n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename);
n2ValueSource << OpenCLExpressionUtilities::createExpressions(n2ValueExpressions, variables, functionDefinitions, "temp", prefix+"functionParams"); n2ValueSource << OpenCLExpressionUtilities::createExpressions(n2ValueExpressions, variables, functionDefinitions, "temp", prefix+"functionParams");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_VALUE"] = n2ValueSource.str(); string n2ValueStr = n2ValueSource.str();
replacements["COMPUTE_VALUE"] = n2ValueStr;
stringstream extraArgs, loadLocal1, loadLocal2, load1, load2; stringstream extraArgs, loadLocal1, loadLocal2, load1, load2;
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
extraArgs << ", __global const float* globals"; extraArgs << ", __global const float* globals";
pairValueUsesParam.resize(params->getBuffers().size(), false);
for (int i = 0; i < (int) params->getBuffers().size(); i++) { for (int i = 0; i < (int) params->getBuffers().size(); i++) {
const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i]; const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
string paramName = "params"+intToString(i+1); string paramName = "params"+intToString(i+1);
extraArgs << ", __global const " << buffer.getType() << "* restrict global_" << paramName << ", __local " << buffer.getType() << "* restrict local_" << paramName; if (n2ValueStr.find(paramName+"1") != n2ValueStr.npos || n2ValueStr.find(paramName+"2") != n2ValueStr.npos) {
loadLocal1 << "local_" << paramName << "[localAtomIndex] = " << paramName << "1;\n"; extraArgs << ", __global const " << buffer.getType() << "* restrict global_" << paramName << ", __local " << buffer.getType() << "* restrict local_" << paramName;
loadLocal2 << "local_" << paramName << "[localAtomIndex] = global_" << paramName << "[j];\n"; loadLocal1 << "local_" << paramName << "[localAtomIndex] = " << paramName << "1;\n";
load1 << buffer.getType() << " " << paramName << "1 = global_" << paramName << "[atom1];\n"; loadLocal2 << "local_" << paramName << "[localAtomIndex] = global_" << paramName << "[j];\n";
load2 << buffer.getType() << " " << paramName << "2 = local_" << paramName << "[atom2];\n"; load1 << buffer.getType() << " " << paramName << "1 = global_" << paramName << "[atom1];\n";
load2 << buffer.getType() << " " << paramName << "2 = local_" << paramName << "[atom2];\n";
pairValueUsesParam[i] = true;
}
} }
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str(); replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str(); replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
...@@ -2054,15 +2067,19 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2054,15 +2067,19 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
n2EnergyExpressions["dEdR += "] = Lepton::Parser::parse(expression, functions).differentiate("r").optimize(); n2EnergyExpressions["dEdR += "] = Lepton::Parser::parse(expression, functions).differentiate("r").optimize();
if (useLong) { if (useLong) {
for (int j = 0; j < force.getNumComputedValues(); j++) { for (int j = 0; j < force.getNumComputedValues(); j++) {
string index = intToString(j+1); if (needChainForValue[j]) {
n2EnergyExpressions["/*"+intToString(i+1)+"*/ deriv"+index+"_1 += "] = energyDerivExpressions[i][2*j]; string index = intToString(j+1);
n2EnergyExpressions["/*"+intToString(i+1)+"*/ deriv"+index+"_2 += "] = energyDerivExpressions[i][2*j+1]; n2EnergyExpressions["/*"+intToString(i+1)+"*/ deriv"+index+"_1 += "] = energyDerivExpressions[i][2*j];
n2EnergyExpressions["/*"+intToString(i+1)+"*/ deriv"+index+"_2 += "] = energyDerivExpressions[i][2*j+1];
}
} }
} }
else { else {
for (int j = 0; j < force.getNumComputedValues(); j++) { for (int j = 0; j < force.getNumComputedValues(); j++) {
n2EnergyExpressions["/*"+intToString(i+1)+"*/ deriv"+energyDerivs->getParameterSuffix(j, "_1")+" += "] = energyDerivExpressions[i][2*j]; if (needChainForValue[j]) {
n2EnergyExpressions["/*"+intToString(i+1)+"*/ deriv"+energyDerivs->getParameterSuffix(j, "_2")+" += "] = energyDerivExpressions[i][2*j+1]; n2EnergyExpressions["/*"+intToString(i+1)+"*/ deriv"+energyDerivs->getParameterSuffix(j, "_1")+" += "] = energyDerivExpressions[i][2*j];
n2EnergyExpressions["/*"+intToString(i+1)+"*/ deriv"+energyDerivs->getParameterSuffix(j, "_2")+" += "] = energyDerivExpressions[i][2*j+1];
}
} }
} }
if (exclude) if (exclude)
...@@ -2072,27 +2089,36 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2072,27 +2089,36 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
n2EnergySource << "}\n"; n2EnergySource << "}\n";
} }
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_INTERACTION"] = n2EnergySource.str(); string n2EnergyStr = n2EnergySource.str();
replacements["COMPUTE_INTERACTION"] = n2EnergyStr;
stringstream extraArgs, loadLocal1, loadLocal2, clearLocal, load1, load2, declare1, recordDeriv, storeDerivs1, storeDerivs2, declareTemps, setTemps; stringstream extraArgs, loadLocal1, loadLocal2, clearLocal, load1, load2, declare1, recordDeriv, storeDerivs1, storeDerivs2, declareTemps, setTemps;
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
extraArgs << ", __global const float* globals"; extraArgs << ", __global const float* globals";
pairEnergyUsesParam.resize(params->getBuffers().size(), false);
for (int i = 0; i < (int) params->getBuffers().size(); i++) { for (int i = 0; i < (int) params->getBuffers().size(); i++) {
const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i]; const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
string paramName = "params"+intToString(i+1); string paramName = "params"+intToString(i+1);
extraArgs << ", __global const " << buffer.getType() << "* restrict global_" << paramName << ", __local " << buffer.getType() << "* restrict local_" << paramName; if (n2EnergyStr.find(paramName+"1") != n2EnergyStr.npos || n2EnergyStr.find(paramName+"2") != n2EnergyStr.npos) {
loadLocal1 << "local_" << paramName << "[localAtomIndex] = " << paramName << "1;\n"; extraArgs << ", __global const " << buffer.getType() << "* restrict global_" << paramName << ", __local " << buffer.getType() << "* restrict local_" << paramName;
loadLocal2 << "local_" << paramName << "[localAtomIndex] = global_" << paramName << "[j];\n"; loadLocal1 << "local_" << paramName << "[localAtomIndex] = " << paramName << "1;\n";
load1 << buffer.getType() << " " << paramName << "1 = global_" << paramName << "[atom1];\n"; loadLocal2 << "local_" << paramName << "[localAtomIndex] = global_" << paramName << "[j];\n";
load2 << buffer.getType() << " " << paramName << "2 = local_" << paramName << "[atom2];\n"; load1 << buffer.getType() << " " << paramName << "1 = global_" << paramName << "[atom1];\n";
load2 << buffer.getType() << " " << paramName << "2 = local_" << paramName << "[atom2];\n";
pairEnergyUsesParam[i] = true;
}
} }
pairEnergyUsesValue.resize(computedValues->getBuffers().size(), false);
for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) { for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i]; const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i];
string valueName = "values"+intToString(i+1); string valueName = "values"+intToString(i+1);
extraArgs << ", __global const " << buffer.getType() << "* restrict global_" << valueName << ", __local " << buffer.getType() << "* restrict local_" << valueName; if (n2EnergyStr.find(valueName+"1") != n2EnergyStr.npos || n2EnergyStr.find(valueName+"2") != n2EnergyStr.npos) {
loadLocal1 << "local_" << valueName << "[localAtomIndex] = " << valueName << "1;\n"; extraArgs << ", __global const " << buffer.getType() << "* restrict global_" << valueName << ", __local " << buffer.getType() << "* restrict local_" << valueName;
loadLocal2 << "local_" << valueName << "[localAtomIndex] = global_" << valueName << "[j];\n"; loadLocal1 << "local_" << valueName << "[localAtomIndex] = " << valueName << "1;\n";
load1 << buffer.getType() << " " << valueName << "1 = global_" << valueName << "[atom1];\n"; loadLocal2 << "local_" << valueName << "[localAtomIndex] = global_" << valueName << "[j];\n";
load2 << buffer.getType() << " " << valueName << "2 = local_" << valueName << "[atom2];\n"; load1 << buffer.getType() << " " << valueName << "1 = global_" << valueName << "[atom1];\n";
load2 << buffer.getType() << " " << valueName << "2 = local_" << valueName << "[atom2];\n";
pairEnergyUsesValue[i] = true;
}
} }
if (useLong) { if (useLong) {
extraArgs << ", __global long* restrict derivBuffers"; extraArgs << ", __global long* restrict derivBuffers";
...@@ -2193,6 +2219,9 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2193,6 +2219,9 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++)
reduce << "REDUCE_VALUE(derivBuffers" << intToString(i+1) << ", " << energyDerivs->getBuffers()[i].getType() << ")\n"; reduce << "REDUCE_VALUE(derivBuffers" << intToString(i+1) << ", " << energyDerivs->getBuffers()[i].getType() << ")\n";
} }
// Compute the various expressions.
map<string, string> variables; map<string, string> variables;
variables["x"] = "pos.x"; variables["x"] = "pos.x";
variables["y"] = "pos.y"; variables["y"] = "pos.y";
...@@ -2203,7 +2232,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2203,7 +2232,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
variables[force.getGlobalParameterName(i)] = "globals["+intToString(i)+"]"; variables[force.getGlobalParameterName(i)] = "globals["+intToString(i)+"]";
for (int i = 0; i < force.getNumComputedValues(); i++) for (int i = 0; i < force.getNumComputedValues(); i++)
variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]"); variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]");
map<string, Lepton::ParsedExpression> energyExpressions; map<string, Lepton::ParsedExpression> expressions;
for (int i = 0; i < force.getNumEnergyTerms(); i++) { for (int i = 0; i < force.getNumEnergyTerms(); i++) {
string expression; string expression;
CustomGBForce::ComputationType type; CustomGBForce::ComputationType type;
...@@ -2211,25 +2240,38 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2211,25 +2240,38 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
if (type != CustomGBForce::SingleParticle) if (type != CustomGBForce::SingleParticle)
continue; continue;
Lepton::ParsedExpression parsed = Lepton::Parser::parse(expression, functions).optimize(); Lepton::ParsedExpression parsed = Lepton::Parser::parse(expression, functions).optimize();
energyExpressions["/*"+intToString(i+1)+"*/ energy += "] = parsed; expressions["/*"+intToString(i+1)+"*/ energy += "] = parsed;
for (int j = 0; j < force.getNumComputedValues(); j++) for (int j = 0; j < force.getNumComputedValues(); j++)
energyExpressions["/*"+intToString(i+1)+"*/ deriv"+energyDerivs->getParameterSuffix(j)+" += "] = energyDerivExpressions[i][j]; expressions["/*"+intToString(i+1)+"*/ deriv"+energyDerivs->getParameterSuffix(j)+" += "] = energyDerivExpressions[i][j];
Lepton::ParsedExpression gradx = parsed.differentiate("x").optimize(); Lepton::ParsedExpression gradx = parsed.differentiate("x").optimize();
Lepton::ParsedExpression grady = parsed.differentiate("y").optimize(); Lepton::ParsedExpression grady = parsed.differentiate("y").optimize();
Lepton::ParsedExpression gradz = parsed.differentiate("z").optimize(); Lepton::ParsedExpression gradz = parsed.differentiate("z").optimize();
if (!isZeroExpression(gradx)) if (!isZeroExpression(gradx))
energyExpressions["/*"+intToString(i+1)+"*/ force.x -= "] = gradx; expressions["/*"+intToString(i+1)+"*/ force.x -= "] = gradx;
if (!isZeroExpression(grady)) if (!isZeroExpression(grady))
energyExpressions["/*"+intToString(i+1)+"*/ force.y -= "] = grady; expressions["/*"+intToString(i+1)+"*/ force.y -= "] = grady;
if (!isZeroExpression(gradz)) if (!isZeroExpression(gradz))
energyExpressions["/*"+intToString(i+1)+"*/ force.z -= "] = gradz; expressions["/*"+intToString(i+1)+"*/ force.z -= "] = gradz;
}
for (int i = 1; i < force.getNumComputedValues(); i++)
for (int j = 0; j < i; j++)
expressions["float dV"+intToString(i)+"dV"+intToString(j)+" = "] = valueDerivExpressions[i][j];
compute << OpenCLExpressionUtilities::createExpressions(expressions, variables, functionDefinitions, "temp", prefix+"functionParams");
// Record values.
compute << "forceBuffers[index] = forceBuffers[index]+force;\n";
for (int i = 1; i < force.getNumComputedValues(); i++) {
compute << "float totalDeriv"<<i<<" = dV"<<i<<"dV0";
for (int j = 1; j < i; j++)
compute << " + totalDeriv"<<j<<"*dV"<<i<<"dV"<<j;
compute << ";\n";
compute << "deriv"<<(i+1)<<" *= totalDeriv"<<i<<";\n";
} }
compute << OpenCLExpressionUtilities::createExpressions(energyExpressions, variables, functionDefinitions, "temp", prefix+"functionParams");
for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) { for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) {
string index = intToString(i+1); string index = intToString(i+1);
compute << "derivBuffers" << index << "[index] = deriv" << index << ";\n"; compute << "derivBuffers" << index << "[index] = deriv" << index << ";\n";
} }
compute << "forceBuffers[index] = forceBuffers[index]+force;\n";
map<string, string> replacements; map<string, string> replacements;
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str(); replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
replacements["REDUCE_DERIVATIVES"] = reduce.str(); replacements["REDUCE_DERIVATIVES"] = reduce.str();
...@@ -2324,8 +2366,8 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2324,8 +2366,8 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
const string& name = force.getPerParticleParameterName(i); const string& name = force.getPerParticleParameterName(i);
variables.push_back(makeVariable(name+"1", prefix+"params"+params->getParameterSuffix(i, "1"))); variables.push_back(makeVariable(name+"1", prefix+"params"+params->getParameterSuffix(i, "1")));
variables.push_back(makeVariable(name+"2", prefix+"params"+params->getParameterSuffix(i, "2"))); variables.push_back(makeVariable(name+"2", prefix+"params"+params->getParameterSuffix(i, "2")));
rename[name+"1"] = name+"2"; rename[name+"1"] = name+"2";
rename[name+"2"] = name+"1"; rename[name+"2"] = name+"1";
} }
map<string, Lepton::ParsedExpression> derivExpressions; map<string, Lepton::ParsedExpression> derivExpressions;
stringstream chainSource; stringstream chainSource;
...@@ -2333,75 +2375,44 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2333,75 +2375,44 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
derivExpressions["float dV0dR1 = "] = dVdR; derivExpressions["float dV0dR1 = "] = dVdR;
derivExpressions["float dV0dR2 = "] = dVdR.renameVariables(rename); derivExpressions["float dV0dR2 = "] = dVdR.renameVariables(rename);
chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp0_", prefix+"functionParams"); chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp0_", prefix+"functionParams");
if (useExclusionsForValue) if (needChainForValue[0]) {
chainSource << "if (!isExcluded) {\n"; if (useExclusionsForValue)
chainSource << "tempForce -= dV0dR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "1") << ";\n"; chainSource << "if (!isExcluded) {\n";
chainSource << "tempForce -= dV0dR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "2") << ";\n"; chainSource << "tempForce -= dV0dR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "1") << ";\n";
if (useExclusionsForValue) chainSource << "tempForce -= dV0dR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "2") << ";\n";
chainSource << "}\n"; if (useExclusionsForValue)
variables = globalVariables; chainSource << "}\n";
map<string, string> rename1;
map<string, string> rename2;
variables.push_back(makeVariable("x1", "posq1.x"));
variables.push_back(makeVariable("y1", "posq1.y"));
variables.push_back(makeVariable("z1", "posq1.z"));
variables.push_back(makeVariable("x2", "posq2.x"));
variables.push_back(makeVariable("y2", "posq2.y"));
variables.push_back(makeVariable("z2", "posq2.z"));
rename1["x"] = "x1";
rename1["y"] = "y1";
rename1["z"] = "z1";
rename2["x"] = "x2";
rename2["y"] = "y2";
rename2["z"] = "z2";
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
const string& name = force.getPerParticleParameterName(i);
variables.push_back(makeVariable(name+"1", prefix+"params"+params->getParameterSuffix(i, "1")));
variables.push_back(makeVariable(name+"2", prefix+"params"+params->getParameterSuffix(i, "2")));
rename1[name] = name+"1";
rename2[name] = name+"2";
} }
for (int i = 0; i < force.getNumComputedValues(); i++) { for (int i = 1; i < force.getNumComputedValues(); i++) {
const string& name = computedValueNames[i]; if (needChainForValue[i]) {
variables.push_back(makeVariable(name+"1", prefix+"values"+computedValues->getParameterSuffix(i, "1"))); chainSource << "tempForce -= dV0dR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "1") << ";\n";
variables.push_back(makeVariable(name+"2", prefix+"values"+computedValues->getParameterSuffix(i, "2"))); chainSource << "tempForce -= dV0dR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "2") << ";\n";
rename1[name] = name+"1";
rename2[name] = name+"2";
if (i == 0)
continue;
string is = intToString(i);
chainSource << "float dV"+is+"dR1 = 0;\n";
chainSource << "float dV"+is+"dR2 = 0;\n";
for (int j = 0; j < i; j++) {
string js = intToString(j);
Lepton::ParsedExpression dVdV = Lepton::Parser::parse(computedValueExpressions[i], functions).differentiate(computedValueNames[j]).optimize();
derivExpressions.clear();
derivExpressions["dV"+is+"dR1 += dV"+js+"dR1*"] = dVdV.renameVariables(rename1);
derivExpressions["dV"+is+"dR2 += dV"+js+"dR2*"] = dVdV.renameVariables(rename2);
chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp"+is+"_"+js+"_", prefix+"functionParams");
} }
chainSource << "tempForce -= dV"<< is << "dR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "1") << ";\n";
chainSource << "tempForce -= dV"<< is << "dR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "2") << ";\n";
} }
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = chainSource.str(); string chainStr = chainSource.str();
replacements["COMPUTE_FORCE"] = chainStr;
string source = cl.replaceStrings(OpenCLKernelSources::customGBChainRule, replacements); string source = cl.replaceStrings(OpenCLKernelSources::customGBChainRule, replacements);
vector<OpenCLNonbondedUtilities::ParameterInfo> parameters; vector<OpenCLNonbondedUtilities::ParameterInfo> parameters;
vector<OpenCLNonbondedUtilities::ParameterInfo> arguments; vector<OpenCLNonbondedUtilities::ParameterInfo> arguments;
for (int i = 0; i < (int) params->getBuffers().size(); i++) { for (int i = 0; i < (int) params->getBuffers().size(); i++) {
const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i]; const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
string paramName = prefix+"params"+intToString(i+1); string paramName = prefix+"params"+intToString(i+1);
parameters.push_back(OpenCLNonbondedUtilities::ParameterInfo(paramName, buffer.getComponentType(), buffer.getNumComponents(), buffer.getSize(), buffer.getMemory())); if (chainStr.find(paramName+"1") != chainStr.npos || chainStr.find(paramName+"2") != chainStr.npos)
parameters.push_back(OpenCLNonbondedUtilities::ParameterInfo(paramName, buffer.getComponentType(), buffer.getNumComponents(), buffer.getSize(), buffer.getMemory()));
} }
for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) { for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i]; const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i];
string paramName = prefix+"values"+intToString(i+1); string paramName = prefix+"values"+intToString(i+1);
parameters.push_back(OpenCLNonbondedUtilities::ParameterInfo(paramName, buffer.getComponentType(), buffer.getNumComponents(), buffer.getSize(), buffer.getMemory())); if (chainStr.find(paramName+"1") != chainStr.npos || chainStr.find(paramName+"2") != chainStr.npos)
parameters.push_back(OpenCLNonbondedUtilities::ParameterInfo(paramName, buffer.getComponentType(), buffer.getNumComponents(), buffer.getSize(), buffer.getMemory()));
} }
for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) { for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) {
const OpenCLNonbondedUtilities::ParameterInfo& buffer = energyDerivs->getBuffers()[i]; if (needChainForValue[i]) {
string paramName = prefix+"dEdV"+intToString(i+1); const OpenCLNonbondedUtilities::ParameterInfo& buffer = energyDerivs->getBuffers()[i];
parameters.push_back(OpenCLNonbondedUtilities::ParameterInfo(paramName, buffer.getComponentType(), buffer.getNumComponents(), buffer.getSize(), buffer.getMemory())); string paramName = prefix+"dEdV"+intToString(i+1);
parameters.push_back(OpenCLNonbondedUtilities::ParameterInfo(paramName, buffer.getComponentType(), buffer.getNumComponents(), buffer.getSize(), buffer.getMemory()));
}
} }
if (globals != NULL) { if (globals != NULL) {
globals->upload(globalParamValues); globals->upload(globalParamValues);
...@@ -2465,9 +2476,11 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -2465,9 +2476,11 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
if (globals != NULL) if (globals != NULL)
pairValueKernel.setArg<cl::Buffer>(index++, globals->getDeviceBuffer()); pairValueKernel.setArg<cl::Buffer>(index++, globals->getDeviceBuffer());
for (int i = 0; i < (int) params->getBuffers().size(); i++) { for (int i = 0; i < (int) params->getBuffers().size(); i++) {
const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i]; if (pairValueUsesParam[i]) {
pairValueKernel.setArg<cl::Memory>(index++, buffer.getMemory()); const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
pairValueKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*buffer.getSize(), NULL); pairValueKernel.setArg<cl::Memory>(index++, buffer.getMemory());
pairValueKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*buffer.getSize(), NULL);
}
} }
if (tabulatedFunctionParams != NULL) { if (tabulatedFunctionParams != NULL) {
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
...@@ -2515,14 +2528,18 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -2515,14 +2528,18 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
if (globals != NULL) if (globals != NULL)
pairEnergyKernel.setArg<cl::Buffer>(index++, globals->getDeviceBuffer()); pairEnergyKernel.setArg<cl::Buffer>(index++, globals->getDeviceBuffer());
for (int i = 0; i < (int) params->getBuffers().size(); i++) { for (int i = 0; i < (int) params->getBuffers().size(); i++) {
const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i]; if (pairEnergyUsesParam[i]) {
pairEnergyKernel.setArg<cl::Memory>(index++, buffer.getMemory()); const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
pairEnergyKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*buffer.getSize(), NULL); pairEnergyKernel.setArg<cl::Memory>(index++, buffer.getMemory());
pairEnergyKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*buffer.getSize(), NULL);
}
} }
for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) { for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i]; if (pairEnergyUsesValue[i]) {
pairEnergyKernel.setArg<cl::Memory>(index++, buffer.getMemory()); const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i];
pairEnergyKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*buffer.getSize(), NULL); pairEnergyKernel.setArg<cl::Memory>(index++, buffer.getMemory());
pairEnergyKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*buffer.getSize(), NULL);
}
} }
if (useLong) { if (useLong) {
pairEnergyKernel.setArg<cl::Memory>(index++, longEnergyDerivs->getDeviceBuffer()); pairEnergyKernel.setArg<cl::Memory>(index++, longEnergyDerivs->getDeviceBuffer());
......
...@@ -687,6 +687,7 @@ private: ...@@ -687,6 +687,7 @@ private:
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<cl_float> globalParamValues; std::vector<cl_float> globalParamValues;
std::vector<OpenCLArray<mm_float4>*> tabulatedFunctions; std::vector<OpenCLArray<mm_float4>*> tabulatedFunctions;
std::vector<bool> pairValueUsesParam, pairEnergyUsesParam, pairEnergyUsesValue;
System& system; System& system;
cl::Kernel pairValueKernel, perParticleValueKernel, pairEnergyKernel, perParticleEnergyKernel, gradientChainRuleKernel; cl::Kernel pairValueKernel, perParticleValueKernel, pairEnergyKernel, perParticleEnergyKernel, gradientChainRuleKernel;
}; };
......
...@@ -32,30 +32,33 @@ ...@@ -32,30 +32,33 @@
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
OpenCLParameterSet::OpenCLParameterSet(OpenCLContext& context, int numParameters, int numObjects, const string& name) : OpenCLParameterSet::OpenCLParameterSet(OpenCLContext& context, int numParameters, int numObjects, const string& name, bool bufferPerParameter) :
context(context), numParameters(numParameters), numObjects(numObjects), name(name) { context(context), numParameters(numParameters), numObjects(numObjects), name(name) {
int params = numParameters; int params = numParameters;
int bufferCount = 0; int bufferCount = 0;
try { try {
while (params > 2) { if (!bufferPerParameter) {
cl::Buffer* buf = new cl::Buffer(context.getContext(), CL_MEM_READ_WRITE, numObjects*sizeof(mm_float4)); while (params > 2) {
std::stringstream name; cl::Buffer* buf = new cl::Buffer(context.getContext(), CL_MEM_READ_WRITE, numObjects*sizeof(mm_float4));
name << "param" << (++bufferCount); std::stringstream name;
buffers.push_back(OpenCLNonbondedUtilities::ParameterInfo(name.str(), "float", 4, sizeof(mm_float4), *buf)); name << "param" << (++bufferCount);
params -= 4; buffers.push_back(OpenCLNonbondedUtilities::ParameterInfo(name.str(), "float", 4, sizeof(mm_float4), *buf));
} params -= 4;
if (params > 1) { }
cl::Buffer* buf = new cl::Buffer(context.getContext(), CL_MEM_READ_WRITE, numObjects*sizeof(mm_float2)); if (params > 1) {
std::stringstream name; cl::Buffer* buf = new cl::Buffer(context.getContext(), CL_MEM_READ_WRITE, numObjects*sizeof(mm_float2));
name << "param" << (++bufferCount); std::stringstream name;
buffers.push_back(OpenCLNonbondedUtilities::ParameterInfo(name.str(), "float", 2, sizeof(mm_float2), *buf)); name << "param" << (++bufferCount);
params -= 2; buffers.push_back(OpenCLNonbondedUtilities::ParameterInfo(name.str(), "float", 2, sizeof(mm_float2), *buf));
params -= 2;
}
} }
if (params > 0) { while (params > 0) {
cl::Buffer* buf = new cl::Buffer(context.getContext(), CL_MEM_READ_WRITE, numObjects*sizeof(cl_float)); cl::Buffer* buf = new cl::Buffer(context.getContext(), CL_MEM_READ_WRITE, numObjects*sizeof(cl_float));
std::stringstream name; std::stringstream name;
name << "param" << (++bufferCount); name << "param" << (++bufferCount);
buffers.push_back(OpenCLNonbondedUtilities::ParameterInfo(name.str(), "float", 1, sizeof(cl_float), *buf)); buffers.push_back(OpenCLNonbondedUtilities::ParameterInfo(name.str(), "float", 1, sizeof(cl_float), *buf));
params--;
} }
} }
catch (cl::Error err) { catch (cl::Error err) {
...@@ -106,6 +109,7 @@ void OpenCLParameterSet::getParameterValues(vector<vector<cl_float> >& values) c ...@@ -106,6 +109,7 @@ void OpenCLParameterSet::getParameterValues(vector<vector<cl_float> >& values) c
context.getQueue().enqueueReadBuffer(reinterpret_cast<cl::Buffer&>(buffers[i].getMemory()), CL_TRUE, 0, numObjects*buffers[i].getSize(), &data[0]); context.getQueue().enqueueReadBuffer(reinterpret_cast<cl::Buffer&>(buffers[i].getMemory()), CL_TRUE, 0, numObjects*buffers[i].getSize(), &data[0]);
for (int j = 0; j < numObjects; j++) for (int j = 0; j < numObjects; j++)
values[j][base] = data[j]; values[j][base] = data[j];
base++;
} }
else else
throw OpenMMException("Internal error: Unknown buffer type in OpenCLParameterSet"); throw OpenMMException("Internal error: Unknown buffer type in OpenCLParameterSet");
...@@ -151,6 +155,7 @@ void OpenCLParameterSet::setParameterValues(const vector<vector<cl_float> >& val ...@@ -151,6 +155,7 @@ void OpenCLParameterSet::setParameterValues(const vector<vector<cl_float> >& val
for (int j = 0; j < numObjects; j++) for (int j = 0; j < numObjects; j++)
data[j] = values[j][base]; data[j] = values[j][base];
context.getQueue().enqueueWriteBuffer(reinterpret_cast<cl::Buffer&>(buffers[i].getMemory()), CL_TRUE, 0, numObjects*buffers[i].getSize(), &data[0]); context.getQueue().enqueueWriteBuffer(reinterpret_cast<cl::Buffer&>(buffers[i].getMemory()), CL_TRUE, 0, numObjects*buffers[i].getSize(), &data[0]);
base++;
} }
else else
throw OpenMMException("Internal error: Unknown buffer type in OpenCLParameterSet"); throw OpenMMException("Internal error: Unknown buffer type in OpenCLParameterSet");
......
...@@ -49,8 +49,10 @@ public: ...@@ -49,8 +49,10 @@ public:
* @param numParameters the number of parameters for each object * @param numParameters the number of parameters for each object
* @param numObjects the number of objects to store parameter values for * @param numObjects the number of objects to store parameter values for
* @param name the name of the parameter set * @param name the name of the parameter set
* @param bufferPerParameter if true, a separate cl::Buffer is created for each parameter. If false,
* multiple parameters may be combined into a single buffer.
*/ */
OpenCLParameterSet(OpenCLContext& context, int numParameters, int numObjects, const std::string& name); OpenCLParameterSet(OpenCLContext& context, int numParameters, int numObjects, const std::string& name, bool bufferPerParameter=false);
~OpenCLParameterSet(); ~OpenCLParameterSet();
/** /**
* Get the number of parameters. * Get the number of parameters.
......
#define REDUCE_VALUE(NAME, TYPE) \ #define REDUCE_VALUE(NAME, TYPE) {\
TYPE sum = NAME[index]; \ TYPE sum = NAME[index]; \
for (int i = index+bufferSize; i < totalSize; i += bufferSize) \ for (int i = index+bufferSize; i < totalSize; i += bufferSize) \
sum += NAME[i]; \ sum += NAME[i]; \
NAME[index] = sum; NAME[index] = sum; \
}
/** /**
* Reduce the derivatives computed in the N^2 energy kernel, and compute all per-particle energy terms. * Reduce the derivatives computed in the N^2 energy kernel, and compute all per-particle energy terms.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment