Commit 74efa95f authored by peastman's avatar peastman
Browse files

OpenCL implementation of parameter derivatives for CustomGBForce

parent df07fbe9
...@@ -808,13 +808,15 @@ public: ...@@ -808,13 +808,15 @@ public:
void copyParametersToContext(ContextImpl& context, const CustomGBForce& force); void copyParametersToContext(ContextImpl& context, const CustomGBForce& force);
private: private:
double cutoff; double cutoff;
bool hasInitializedKernels, needParameterGradient; bool hasInitializedKernels, needParameterGradient, needEnergyParamDerivs;
int maxTiles, numComputedValues; int maxTiles, numComputedValues;
OpenCLContext& cl; OpenCLContext& cl;
OpenCLParameterSet* params; OpenCLParameterSet* params;
OpenCLParameterSet* computedValues; OpenCLParameterSet* computedValues;
OpenCLParameterSet* energyDerivs; OpenCLParameterSet* energyDerivs;
OpenCLParameterSet* energyDerivChain; OpenCLParameterSet* energyDerivChain;
std::vector<OpenCLParameterSet*> dValuedParam;
std::vector<OpenCLArray*> dValue0dParam;
OpenCLArray* longEnergyDerivs; OpenCLArray* longEnergyDerivs;
OpenCLArray* globals; OpenCLArray* globals;
OpenCLArray* valueBuffers; OpenCLArray* valueBuffers;
......
...@@ -204,7 +204,7 @@ void OpenCLBondedUtilities::initialize(const System& system) { ...@@ -204,7 +204,7 @@ void OpenCLBondedUtilities::initialize(const System& system) {
for (int i = 0; i < (int) arguments.size(); i++) for (int i = 0; i < (int) arguments.size(); i++)
s<<", __global "<<argTypes[i]<<"* customArg"<<(i+1); s<<", __global "<<argTypes[i]<<"* customArg"<<(i+1);
if (energyParameterDerivatives.size() > 0) if (energyParameterDerivatives.size() > 0)
s<<", __global mixed* energyParamDerivs"; s<<", __global mixed* restrict energyParamDerivs";
s<<") {\n"; s<<") {\n";
s<<"mixed energy = 0;\n"; s<<"mixed energy = 0;\n";
for (int i = 0; i < energyParameterDerivatives.size(); i++) for (int i = 0; i < energyParameterDerivatives.size(); i++)
...@@ -219,7 +219,7 @@ void OpenCLBondedUtilities::initialize(const System& system) { ...@@ -219,7 +219,7 @@ void OpenCLBondedUtilities::initialize(const System& system) {
for (int i = 0; i < energyParameterDerivatives.size(); i++) for (int i = 0; i < energyParameterDerivatives.size(); i++)
for (int index = 0; index < numDerivs; index++) for (int index = 0; index < numDerivs; index++)
if (allParamDerivNames[index] == energyParameterDerivatives[i]) if (allParamDerivNames[index] == energyParameterDerivatives[i])
s<<"energyParamDerivs[get_global_id(0)*"<<numDerivs<<"+"<<i<<"] += energyParamDeriv"<<i<<";\n"; s<<"energyParamDerivs[get_global_id(0)*"<<numDerivs<<"+"<<index<<"] += energyParamDeriv"<<i<<";\n";
s<<"}\n"; s<<"}\n";
map<string, string> defines; map<string, string> defines;
defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms()); defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
......
...@@ -3010,6 +3010,10 @@ OpenCLCalcCustomGBForceKernel::~OpenCLCalcCustomGBForceKernel() { ...@@ -3010,6 +3010,10 @@ OpenCLCalcCustomGBForceKernel::~OpenCLCalcCustomGBForceKernel() {
delete longValueBuffers; delete longValueBuffers;
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
delete tabulatedFunctions[i]; delete tabulatedFunctions[i];
for (int i = 0; i < dValue0dParam.size(); i++)
delete dValue0dParam[i];
for (int i = 0; i < dValuedParam.size(); i++)
delete dValuedParam[i];
} }
void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const CustomGBForce& force) { void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const CustomGBForce& force) {
...@@ -3101,18 +3105,24 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3101,18 +3105,24 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
vector<vector<Lepton::ParsedExpression> > valueGradientExpressions(force.getNumComputedValues()); vector<vector<Lepton::ParsedExpression> > valueGradientExpressions(force.getNumComputedValues());
vector<vector<Lepton::ParsedExpression> > valueDerivExpressions(force.getNumComputedValues()); vector<vector<Lepton::ParsedExpression> > valueDerivExpressions(force.getNumComputedValues());
vector<vector<Lepton::ParsedExpression> > valueParamDerivExpressions(force.getNumComputedValues());
needParameterGradient = false; needParameterGradient = false;
for (int i = 1; i < force.getNumComputedValues(); i++) { for (int i = 0; i < force.getNumComputedValues(); i++) {
Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize(); Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
valueGradientExpressions[i].push_back(ex.differentiate("x").optimize()); if (i > 0) {
valueGradientExpressions[i].push_back(ex.differentiate("y").optimize()); valueGradientExpressions[i].push_back(ex.differentiate("x").optimize());
valueGradientExpressions[i].push_back(ex.differentiate("z").optimize()); valueGradientExpressions[i].push_back(ex.differentiate("y").optimize());
if (!isZeroExpression(valueGradientExpressions[i][0]) || !isZeroExpression(valueGradientExpressions[i][1]) || !isZeroExpression(valueGradientExpressions[i][2])) valueGradientExpressions[i].push_back(ex.differentiate("z").optimize());
needParameterGradient = true; if (!isZeroExpression(valueGradientExpressions[i][0]) || !isZeroExpression(valueGradientExpressions[i][1]) || !isZeroExpression(valueGradientExpressions[i][2]))
for (int j = 0; j < i; j++) needParameterGradient = true;
valueDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]).optimize()); for (int j = 0; j < i; j++)
valueDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]).optimize());
}
for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++)
valueParamDerivExpressions[i].push_back(ex.differentiate(force.getEnergyParameterDerivativeName(j)).optimize());
} }
vector<vector<Lepton::ParsedExpression> > energyDerivExpressions(force.getNumEnergyTerms()); vector<vector<Lepton::ParsedExpression> > energyDerivExpressions(force.getNumEnergyTerms());
vector<vector<Lepton::ParsedExpression> > energyParamDerivExpressions(force.getNumEnergyTerms());
vector<bool> needChainForValue(force.getNumComputedValues(), false); vector<bool> needChainForValue(force.getNumComputedValues(), false);
for (int i = 0; i < force.getNumEnergyTerms(); i++) { for (int i = 0; i < force.getNumEnergyTerms(); i++) {
string expression; string expression;
...@@ -3134,6 +3144,8 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3134,6 +3144,8 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
needChainForValue[j] = true; needChainForValue[j] = true;
} }
} }
for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++)
energyParamDerivExpressions[i].push_back(ex.differentiate(force.getEnergyParameterDerivativeName(j)).optimize());
} }
bool deviceIsCpu = (cl.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU); bool deviceIsCpu = (cl.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
bool useLong = cl.getSupports64BitGlobalAtomics(); bool useLong = cl.getSupports64BitGlobalAtomics();
...@@ -3144,6 +3156,18 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3144,6 +3156,18 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
else else
energyDerivs = new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms()*cl.getNonbondedUtilities().getNumForceBuffers(), "customGBEnergyDerivatives", true); energyDerivs = new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms()*cl.getNonbondedUtilities().getNumForceBuffers(), "customGBEnergyDerivatives", true);
energyDerivChain = new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms(), "customGBEnergyDerivativeChain", true); energyDerivChain = new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms(), "customGBEnergyDerivativeChain", true);
int elementSize = (cl.getUseDoublePrecision() ? sizeof(cl_double) : sizeof(cl_float));
needEnergyParamDerivs = (force.getNumEnergyParameterDerivatives() > 0);
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
dValuedParam.push_back(new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms(), "dValuedParam", true));
if (useLong)
dValue0dParam.push_back(OpenCLArray::create<cl_long>(cl, cl.getPaddedNumAtoms(), "dValue0dParam"));
else
dValue0dParam.push_back(new OpenCLArray(cl, cl.getPaddedNumAtoms()*cl.getNonbondedUtilities().getNumForceBuffers(), elementSize, "dValue0dParam"));
cl.addAutoclearBuffer(*dValue0dParam.back());
string name = force.getEnergyParameterDerivativeName(i);
cl.addEnergyParameterDerivative(name);
}
// Create the kernels. // Create the kernels.
...@@ -3175,11 +3199,18 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3175,11 +3199,18 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize(); Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize();
n2ValueExpressions["tempValue1 = "] = ex; n2ValueExpressions["tempValue1 = "] = ex;
n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename); n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename);
for (int i = 0; i < valueParamDerivExpressions[0].size(); i++) {
string variableBase = "temp_dValue0dParam"+cl.intToString(i+1);
if (!isZeroExpression(valueParamDerivExpressions[0][i])) {
n2ValueExpressions[variableBase+"_1 = "] = valueParamDerivExpressions[0][i];
n2ValueExpressions[variableBase+"_2 = "] = valueParamDerivExpressions[0][i].renameVariables(rename);
}
}
n2ValueSource << cl.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionList, functionDefinitions, "temp"); n2ValueSource << cl.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionList, functionDefinitions, "temp");
map<string, string> replacements; map<string, string> replacements;
string n2ValueStr = n2ValueSource.str(); string n2ValueStr = n2ValueSource.str();
replacements["COMPUTE_VALUE"] = n2ValueStr; replacements["COMPUTE_VALUE"] = n2ValueStr;
stringstream extraArgs, loadLocal1, loadLocal2, load1, load2; stringstream extraArgs, loadLocal1, loadLocal2, load1, load2, tempDerivs1, tempDerivs2, storeDeriv1, storeDeriv2;
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
extraArgs << ", __global const float* globals"; extraArgs << ", __global const float* globals";
pairValueUsesParam.resize(params->getBuffers().size(), false); pairValueUsesParam.resize(params->getBuffers().size(), false);
...@@ -3195,11 +3226,39 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3195,11 +3226,39 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
pairValueUsesParam[i] = true; pairValueUsesParam[i] = true;
} }
} }
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
string derivName = "dValue0dParam"+cl.intToString(i+1);
if (useLong)
extraArgs << ", __global long* restrict global_" << derivName;
else
extraArgs << ", __global real* restrict global_" << derivName;
extraArgs << ", __local real* restrict local_" << derivName;
loadLocal2 << "local_" << derivName << "[localAtomIndex] = 0;\n";
load1 << "real " << derivName << " = 0;\n";
if (!isZeroExpression(valueParamDerivExpressions[0][i])) {
load2 << "real temp_" << derivName << "_1 = 0;\n";
load2 << "real temp_" << derivName << "_2 = 0;\n";
tempDerivs1 << derivName << " += temp_" << derivName << "_1;\n";
tempDerivs2 << "local_" << derivName << "[tbx+tj] += temp_" << derivName << "_2;\n";
if (useLong) {
storeDeriv1 << "atom_add(&global_" << derivName << "[offset1], (long) (" << derivName << "*0x100000000));\n";
storeDeriv2 << "atom_add(&global_" << derivName << "[offset2], (long) (local_" << derivName << "[get_local_id(0)]*0x100000000));\n";
}
else {
storeDeriv1 << "global_" << derivName << "[offset1] += " << derivName << ";\n";
storeDeriv2 << "global_" << derivName << "[offset2] += local_" << derivName << "[get_local_id(0)];\n";
}
}
}
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str(); replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str(); replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str(); replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
replacements["LOAD_ATOM1_PARAMETERS"] = load1.str(); replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
replacements["LOAD_ATOM2_PARAMETERS"] = load2.str(); replacements["LOAD_ATOM2_PARAMETERS"] = load2.str();
replacements["ADD_TEMP_DERIVS1"] = tempDerivs1.str();
replacements["ADD_TEMP_DERIVS2"] = tempDerivs2.str();
replacements["STORE_PARAM_DERIVS1"] = storeDeriv1.str();
replacements["STORE_PARAM_DERIVS2"] = storeDeriv2.str();
if (useCutoff) if (useCutoff)
pairValueDefines["USE_CUTOFF"] = "1"; pairValueDefines["USE_CUTOFF"] = "1";
if (usePeriodic) if (usePeriodic)
...@@ -3224,7 +3283,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3224,7 +3283,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
{ {
// Create the kernel to reduce the N2 value and calculate other values. // Create the kernel to reduce the N2 value and calculate other values.
stringstream reductionSource, extraArgs; stringstream reductionSource, extraArgs, deriv0;
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
extraArgs << ", __global const float* globals"; extraArgs << ", __global const float* globals";
for (int i = 0; i < (int) params->getBuffers().size(); i++) { for (int i = 0; i < (int) params->getBuffers().size(); i++) {
...@@ -3238,6 +3297,22 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3238,6 +3297,22 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
extraArgs << ", __global " << buffer.getType() << "* restrict global_" << valueName; extraArgs << ", __global " << buffer.getType() << "* restrict global_" << valueName;
reductionSource << buffer.getType() << " local_" << valueName << ";\n"; reductionSource << buffer.getType() << " local_" << valueName << ";\n";
} }
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
string variableName = "dValuedParam_0_"+cl.intToString(i);
if (useLong) {
extraArgs << ", __global const long* restrict dValue0dParam" << i;
deriv0 << "real " << variableName << " = (1.0f/0x100000000)*dValue0dParam[index];\n";
}
else {
extraArgs << ", __global const real* restrict dValue0dParam" << i;
deriv0 << "real " << variableName << " = dValue0dParam" << i << "[index];\n";
deriv0 << "for (int i = index+bufferSize; i < totalSize; i += bufferSize)\n";
deriv0 << " " << variableName << " += dValue0dParam" << i << "[i];\n";
}
for (int j = 0; j < dValuedParam[i]->getBuffers().size(); j++)
extraArgs << ", __global real* restrict global_dValuedParam_" << j << "_" << i;
deriv0 << "global_dValuedParam_0_" << i << "[index] = dValuedParam_0_" << i << ";\n";
}
reductionSource << "local_values" << computedValues->getParameterSuffix(0) << " = sum;\n"; reductionSource << "local_values" << computedValues->getParameterSuffix(0) << " = sum;\n";
map<string, string> variables; map<string, string> variables;
variables["x"] = "pos.x"; variables["x"] = "pos.x";
...@@ -3257,8 +3332,26 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3257,8 +3332,26 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
string valueName = "values"+cl.intToString(i+1); string valueName = "values"+cl.intToString(i+1);
reductionSource << "global_" << valueName << "[index] = local_" << valueName << ";\n"; reductionSource << "global_" << valueName << "[index] = local_" << valueName << ";\n";
} }
if (needEnergyParamDerivs) {
map<string, Lepton::ParsedExpression> derivExpressions;
for (int i = 1; i < force.getNumComputedValues(); i++) {
for (int j = 0; j < valueParamDerivExpressions[i].size(); j++)
derivExpressions["real dValuedParam_"+cl.intToString(i)+"_"+cl.intToString(j)+" = "] = valueParamDerivExpressions[i][j];
for (int j = 0; j < i; j++)
derivExpressions["real dVdV_"+cl.intToString(i)+"_"+cl.intToString(j)+" = "] = valueDerivExpressions[i][j];
}
reductionSource << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, "derivChain_temp");
for (int i = 1; i < force.getNumComputedValues(); i++) {
for (int j = 0; j < i; j++)
for (int k = 0; k < valueParamDerivExpressions[i].size(); k++)
reductionSource << "dValuedParam_" << i << "_" << k << " += dVdV_" << i << "_" << j << "*dValuedParam_" << j <<"_" << k << ";\n";
for (int j = 0; j < valueParamDerivExpressions[i].size(); j++)
reductionSource << "global_dValuedParam_" << i << "_" << j << "[index] = dValuedParam_" << i << "_" << j << ";\n";
}
}
map<string, string> replacements; map<string, string> replacements;
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str(); replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
replacements["REDUCE_PARAM0_DERIV"] = deriv0.str();
replacements["COMPUTE_VALUES"] = reductionSource.str(); replacements["COMPUTE_VALUES"] = reductionSource.str();
map<string, string> defines; map<string, string> defines;
defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms()); defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
...@@ -3313,6 +3406,8 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3313,6 +3406,8 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
} }
} }
} }
for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++)
n2EnergyExpressions["energyParamDeriv"+cl.intToString(j)+" += interactionScale*"] = energyParamDerivExpressions[i][j];
if (exclude) if (exclude)
n2EnergySource << "if (!isExcluded) {\n"; n2EnergySource << "if (!isExcluded) {\n";
n2EnergySource << cl.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionList, functionDefinitions, "temp"); n2EnergySource << cl.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionList, functionDefinitions, "temp");
...@@ -3322,7 +3417,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3322,7 +3417,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
map<string, string> replacements; map<string, string> replacements;
string n2EnergyStr = n2EnergySource.str(); string n2EnergyStr = n2EnergySource.str();
replacements["COMPUTE_INTERACTION"] = n2EnergyStr; replacements["COMPUTE_INTERACTION"] = n2EnergyStr;
stringstream extraArgs, loadLocal1, loadLocal2, clearLocal, load1, load2, declare1, recordDeriv, storeDerivs1, storeDerivs2, declareTemps, setTemps; stringstream extraArgs, loadLocal1, loadLocal2, clearLocal, load1, load2, declare1, recordDeriv, storeDerivs1, storeDerivs2, declareTemps, setTemps, initParamDerivs, saveParamDerivs;
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
extraArgs << ", __global const float* globals"; extraArgs << ", __global const float* globals";
pairEnergyUsesParam.resize(params->getBuffers().size(), false); pairEnergyUsesParam.resize(params->getBuffers().size(), false);
...@@ -3381,6 +3476,17 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3381,6 +3476,17 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
setTemps << "tempDerivBuffer" << index << "[get_local_id(0)] = deriv" << index << "_1;\n"; setTemps << "tempDerivBuffer" << index << "[get_local_id(0)] = deriv" << index << "_1;\n";
} }
} }
if (needEnergyParamDerivs) {
extraArgs << ", __global mixed* restrict energyParamDerivs";
const vector<string>& allParamDerivNames = cl.getEnergyParamDerivNames();
int numDerivs = allParamDerivNames.size();
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
initParamDerivs << "mixed energyParamDeriv" << i << " = 0;\n";
for (int index = 0; index < numDerivs; index++)
if (allParamDerivNames[index] == force.getEnergyParameterDerivativeName(i))
saveParamDerivs << "energyParamDerivs[get_global_id(0)*" << numDerivs << "+" << index << "] += energyParamDeriv" << i << ";\n";
}
}
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str(); replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str(); replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str(); replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
...@@ -3393,6 +3499,8 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3393,6 +3499,8 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
replacements["STORE_DERIVATIVES_2"] = storeDerivs2.str(); replacements["STORE_DERIVATIVES_2"] = storeDerivs2.str();
replacements["DECLARE_TEMP_BUFFERS"] = declareTemps.str(); replacements["DECLARE_TEMP_BUFFERS"] = declareTemps.str();
replacements["SET_TEMP_BUFFERS"] = setTemps.str(); replacements["SET_TEMP_BUFFERS"] = setTemps.str();
replacements["INIT_PARAM_DERIVS"] = initParamDerivs.str();
replacements["SAVE_PARAM_DERIVS"] = saveParamDerivs.str();
if (useCutoff) if (useCutoff)
pairEnergyDefines["USE_CUTOFF"] = "1"; pairEnergyDefines["USE_CUTOFF"] = "1";
if (usePeriodic) if (usePeriodic)
...@@ -3415,7 +3523,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3415,7 +3523,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
{ {
// Create the kernel to reduce the derivatives and calculate per-particle energy terms. // Create the kernel to reduce the derivatives and calculate per-particle energy terms.
stringstream compute, extraArgs, reduce; stringstream compute, extraArgs, reduce, initParamDerivs, saveParamDerivs;
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
extraArgs << ", __global const float* globals"; extraArgs << ", __global const float* globals";
for (int i = 0; i < (int) params->getBuffers().size(); i++) { for (int i = 0; i < (int) params->getBuffers().size(); i++) {
...@@ -3449,6 +3557,17 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3449,6 +3557,17 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++)
reduce << "REDUCE_VALUE(derivBuffers" << cl.intToString(i+1) << ", " << energyDerivs->getBuffers()[i].getType() << ")\n"; reduce << "REDUCE_VALUE(derivBuffers" << cl.intToString(i+1) << ", " << energyDerivs->getBuffers()[i].getType() << ")\n";
} }
if (needEnergyParamDerivs) {
extraArgs << ", __global mixed* restrict energyParamDerivs";
const vector<string>& allParamDerivNames = cl.getEnergyParamDerivNames();
int numDerivs = allParamDerivNames.size();
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
initParamDerivs << "mixed energyParamDeriv" << i << " = 0;\n";
for (int index = 0; index < numDerivs; index++)
if (allParamDerivNames[index] == force.getEnergyParameterDerivativeName(i))
saveParamDerivs << "energyParamDerivs[get_global_id(0)*" << numDerivs << "+" << index << "] += energyParamDeriv" << i << ";\n";
}
}
// Compute the various expressions. // Compute the various expressions.
...@@ -3482,6 +3601,8 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3482,6 +3601,8 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
expressions["/*"+cl.intToString(i+1)+"*/ force.y -= "] = grady; expressions["/*"+cl.intToString(i+1)+"*/ force.y -= "] = grady;
if (!isZeroExpression(gradz)) if (!isZeroExpression(gradz))
expressions["/*"+cl.intToString(i+1)+"*/ force.z -= "] = gradz; expressions["/*"+cl.intToString(i+1)+"*/ force.z -= "] = gradz;
for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++)
expressions["/*"+cl.intToString(i+1)+"*/ energyParamDeriv"+cl.intToString(j)+" += "] = energyParamDerivExpressions[i][j];
} }
for (int i = 1; i < force.getNumComputedValues(); i++) for (int i = 1; i < force.getNumComputedValues(); i++)
for (int j = 0; j < i; j++) for (int j = 0; j < i; j++)
...@@ -3510,16 +3631,19 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3510,16 +3631,19 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str(); replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
replacements["REDUCE_DERIVATIVES"] = reduce.str(); replacements["REDUCE_DERIVATIVES"] = reduce.str();
replacements["COMPUTE_ENERGY"] = compute.str(); replacements["COMPUTE_ENERGY"] = compute.str();
replacements["INIT_PARAM_DERIVS"] = initParamDerivs.str();
replacements["SAVE_PARAM_DERIVS"] = saveParamDerivs.str();
map<string, string> defines; map<string, string> defines;
defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms()); defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms()); defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customGBEnergyPerParticle, replacements), defines); cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customGBEnergyPerParticle, replacements), defines);
perParticleEnergyKernel = cl::Kernel(program, "computePerParticleEnergy"); perParticleEnergyKernel = cl::Kernel(program, "computePerParticleEnergy");
} }
if (needParameterGradient) { if (needParameterGradient || needEnergyParamDerivs) {
// Create the kernel to compute chain rule terms for computed values that depend explicitly on particle coordinates. // Create the kernel to compute chain rule terms for computed values that depend explicitly on particle coordinates, and for
// derivatives with respect to global parameters.
stringstream compute, extraArgs; stringstream compute, extraArgs, initParamDerivs, saveParamDerivs;
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0)
extraArgs << ", __global const float* globals"; extraArgs << ", __global const float* globals";
for (int i = 0; i < (int) params->getBuffers().size(); i++) { for (int i = 0; i < (int) params->getBuffers().size(); i++) {
...@@ -3538,6 +3662,19 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3538,6 +3662,19 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
extraArgs << ", __global " << buffer.getType() << "* restrict derivBuffers" << index; extraArgs << ", __global " << buffer.getType() << "* restrict derivBuffers" << index;
compute << buffer.getType() << " deriv" << index << " = derivBuffers" << index << "[index];\n"; compute << buffer.getType() << " deriv" << index << " = derivBuffers" << index << "[index];\n";
} }
if (needEnergyParamDerivs) {
extraArgs << ", __global mixed* restrict energyParamDerivs";
const vector<string>& allParamDerivNames = cl.getEnergyParamDerivNames();
int numDerivs = allParamDerivNames.size();
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
for (int j = 0; j < dValuedParam[i]->getBuffers().size(); j++)
extraArgs << ", __global real* restrict dValuedParam_" << j << "_" << i;
initParamDerivs << "mixed energyParamDeriv" << i << " = 0;\n";
for (int index = 0; index < numDerivs; index++)
if (allParamDerivNames[index] == force.getEnergyParameterDerivativeName(i))
saveParamDerivs << "energyParamDerivs[get_global_id(0)*" << numDerivs << "+" << index << "] += energyParamDeriv" << i << ";\n";
}
}
map<string, string> variables; map<string, string> variables;
variables["x"] = "pos.x"; variables["x"] = "pos.x";
variables["y"] = "pos.y"; variables["y"] = "pos.y";
...@@ -3548,34 +3685,40 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3548,34 +3685,40 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
variables[force.getGlobalParameterName(i)] = "globals["+cl.intToString(i)+"]"; variables[force.getGlobalParameterName(i)] = "globals["+cl.intToString(i)+"]";
for (int i = 0; i < force.getNumComputedValues(); i++) for (int i = 0; i < force.getNumComputedValues(); i++)
variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]"); variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]");
for (int i = 1; i < force.getNumComputedValues(); i++) { if (needParameterGradient) {
string is = cl.intToString(i); for (int i = 1; i < force.getNumComputedValues(); i++) {
compute << "real4 dV"<<is<<"dR = (real4) 0;\n"; string is = cl.intToString(i);
for (int j = 1; j < i; j++) { compute << "real4 dV"<<is<<"dR = (real4) 0;\n";
if (!isZeroExpression(valueDerivExpressions[i][j])) { for (int j = 1; j < i; j++) {
map<string, Lepton::ParsedExpression> derivExpressions; if (!isZeroExpression(valueDerivExpressions[i][j])) {
string js = cl.intToString(j); map<string, Lepton::ParsedExpression> derivExpressions;
derivExpressions["real dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j]; string js = cl.intToString(j);
compute << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, "temp_"+is+"_"+js); derivExpressions["real dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j];
compute << "dV"<<is<<"dR += dV"<<is<<"dV"<<js<<"*dV"<<js<<"dR;\n"; compute << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, "temp_"+is+"_"+js);
compute << "dV"<<is<<"dR += dV"<<is<<"dV"<<js<<"*dV"<<js<<"dR;\n";
}
} }
map<string, Lepton::ParsedExpression> gradientExpressions;
if (!isZeroExpression(valueGradientExpressions[i][0]))
gradientExpressions["dV"+is+"dR.x += "] = valueGradientExpressions[i][0];
if (!isZeroExpression(valueGradientExpressions[i][1]))
gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1];
if (!isZeroExpression(valueGradientExpressions[i][2]))
gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2];
compute << cl.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionList, functionDefinitions, "temp");
} }
map<string, Lepton::ParsedExpression> gradientExpressions; for (int i = 1; i < force.getNumComputedValues(); i++)
if (!isZeroExpression(valueGradientExpressions[i][0])) compute << "force -= deriv"<<energyDerivs->getParameterSuffix(i)<<"*dV"<<i<<"dR;\n";
gradientExpressions["dV"+is+"dR.x += "] = valueGradientExpressions[i][0];
if (!isZeroExpression(valueGradientExpressions[i][1]))
gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1];
if (!isZeroExpression(valueGradientExpressions[i][2]))
gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2];
compute << cl.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionList, functionDefinitions, "temp");
}
for (int i = 1; i < force.getNumComputedValues(); i++) {
string is = cl.intToString(i);
compute << "force -= deriv"<<energyDerivs->getParameterSuffix(i)<<"*dV"<<is<<"dR;\n";
} }
if (needEnergyParamDerivs)
for (int i = 0; i < force.getNumComputedValues(); i++)
for (int j = 0; j < dValuedParam.size(); j++)
compute << "energyParamDeriv"<<j<<" += deriv"<<energyDerivs->getParameterSuffix(i)<<"*dValuedParam_"<<i<<"_"<<j<<"[index];\n";
map<string, string> replacements; map<string, string> replacements;
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str(); replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
replacements["COMPUTE_FORCES"] = compute.str(); replacements["COMPUTE_FORCES"] = compute.str();
replacements["INIT_PARAM_DERIVS"] = initParamDerivs.str();
replacements["SAVE_PARAM_DERIVS"] = saveParamDerivs.str();
map<string, string> defines; map<string, string> defines;
defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms()); defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customGBGradientChainRule, replacements), defines); cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customGBGradientChainRule, replacements), defines);
...@@ -3749,6 +3892,10 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3749,6 +3892,10 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
pairValueKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*buffer.getSize(), NULL); pairValueKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*buffer.getSize(), NULL);
} }
} }
for (int i = 0; i < dValue0dParam.size(); i++) {
pairValueKernel.setArg<cl::Buffer>(index++, dValue0dParam[i]->getDeviceBuffer());
pairValueKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*dValue0dParam[i]->getElementSize(), NULL);
}
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
pairValueKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer()); pairValueKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
index = 0; index = 0;
...@@ -3762,6 +3909,11 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3762,6 +3909,11 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
perParticleValueKernel.setArg<cl::Memory>(index++, params->getBuffers()[i].getMemory()); perParticleValueKernel.setArg<cl::Memory>(index++, params->getBuffers()[i].getMemory());
for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) for (int i = 0; i < (int) computedValues->getBuffers().size(); i++)
perParticleValueKernel.setArg<cl::Memory>(index++, computedValues->getBuffers()[i].getMemory()); perParticleValueKernel.setArg<cl::Memory>(index++, computedValues->getBuffers()[i].getMemory());
for (int i = 0; i < dValuedParam.size(); i++) {
perParticleValueKernel.setArg<cl::Memory>(index++, dValue0dParam[i]->getDeviceBuffer());
for (int j = 0; j < dValuedParam[i]->getBuffers().size(); j++)
perParticleValueKernel.setArg<cl::Memory>(index++, dValuedParam[i]->getBuffers()[j].getMemory());
}
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
perParticleValueKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer()); perParticleValueKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
index = 0; index = 0;
...@@ -3811,6 +3963,8 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3811,6 +3963,8 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
pairEnergyKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*buffer.getSize(), NULL); pairEnergyKernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize())*buffer.getSize(), NULL);
} }
} }
if (needEnergyParamDerivs)
pairEnergyKernel.setArg<cl::Memory>(index++, cl.getEnergyParamDerivBuffer().getDeviceBuffer());
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
pairEnergyKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer()); pairEnergyKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
index = 0; index = 0;
...@@ -3831,9 +3985,11 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3831,9 +3985,11 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
perParticleEnergyKernel.setArg<cl::Memory>(index++, energyDerivChain->getBuffers()[i].getMemory()); perParticleEnergyKernel.setArg<cl::Memory>(index++, energyDerivChain->getBuffers()[i].getMemory());
if (useLong) if (useLong)
perParticleEnergyKernel.setArg<cl::Memory>(index++, longEnergyDerivs->getDeviceBuffer()); perParticleEnergyKernel.setArg<cl::Memory>(index++, longEnergyDerivs->getDeviceBuffer());
if (needEnergyParamDerivs)
perParticleEnergyKernel.setArg<cl::Memory>(index++, cl.getEnergyParamDerivBuffer().getDeviceBuffer());
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
perParticleEnergyKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer()); perParticleEnergyKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
if (needParameterGradient) { if (needParameterGradient || needEnergyParamDerivs) {
index = 0; index = 0;
gradientChainRuleKernel.setArg<cl::Buffer>(index++, cl.getForceBuffers().getDeviceBuffer()); gradientChainRuleKernel.setArg<cl::Buffer>(index++, cl.getForceBuffers().getDeviceBuffer());
gradientChainRuleKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer()); gradientChainRuleKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
...@@ -3845,6 +4001,12 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3845,6 +4001,12 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
gradientChainRuleKernel.setArg<cl::Memory>(index++, computedValues->getBuffers()[i].getMemory()); gradientChainRuleKernel.setArg<cl::Memory>(index++, computedValues->getBuffers()[i].getMemory());
for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++)
gradientChainRuleKernel.setArg<cl::Memory>(index++, energyDerivs->getBuffers()[i].getMemory()); gradientChainRuleKernel.setArg<cl::Memory>(index++, energyDerivs->getBuffers()[i].getMemory());
if (needEnergyParamDerivs) {
gradientChainRuleKernel.setArg<cl::Buffer>(index++, cl.getEnergyParamDerivBuffer().getDeviceBuffer());
for (int i = 0; i < dValuedParam.size(); i++)
for (int j = 0; j < dValuedParam[i]->getBuffers().size(); j++)
gradientChainRuleKernel.setArg<cl::Memory>(index++, dValuedParam[i]->getBuffers()[j].getMemory());
}
} }
} }
if (globals != NULL) { if (globals != NULL) {
...@@ -3875,7 +4037,7 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3875,7 +4037,7 @@ double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
cl.executeKernel(perParticleValueKernel, cl.getPaddedNumAtoms()); cl.executeKernel(perParticleValueKernel, cl.getPaddedNumAtoms());
cl.executeKernel(pairEnergyKernel, nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize()); cl.executeKernel(pairEnergyKernel, nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize());
cl.executeKernel(perParticleEnergyKernel, cl.getPaddedNumAtoms()); cl.executeKernel(perParticleEnergyKernel, cl.getPaddedNumAtoms());
if (needParameterGradient) if (needParameterGradient || needEnergyParamDerivs)
cl.executeKernel(gradientChainRuleKernel, cl.getPaddedNumAtoms()); cl.executeKernel(gradientChainRuleKernel, cl.getPaddedNumAtoms());
return 0.0; return 0.0;
} }
......
...@@ -605,7 +605,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -605,7 +605,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc
} }
} }
if (energyParameterDerivatives.size() > 0) if (energyParameterDerivatives.size() > 0)
args << ", __global mixed* energyParamDerivs"; args << ", __global mixed* restrict energyParamDerivs";
replacements["PARAMETER_ARGUMENTS"] = args.str(); replacements["PARAMETER_ARGUMENTS"] = args.str();
stringstream loadLocal1; stringstream loadLocal1;
for (int i = 0; i < (int) params.size(); i++) { for (int i = 0; i < (int) params.size(); i++) {
...@@ -666,7 +666,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -666,7 +666,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc
for (int i = 0; i < energyParameterDerivatives.size(); i++) for (int i = 0; i < energyParameterDerivatives.size(); i++)
for (int index = 0; index < numDerivs; index++) for (int index = 0; index < numDerivs; index++)
if (allParamDerivNames[index] == energyParameterDerivatives[i]) if (allParamDerivNames[index] == energyParameterDerivatives[i])
saveDerivs<<"energyParamDerivs[get_global_id(0)*"<<numDerivs<<"+"<<i<<"] += energyParamDeriv"<<i<<";\n"; saveDerivs<<"energyParamDerivs[get_global_id(0)*"<<numDerivs<<"+"<<index<<"] += energyParamDeriv"<<i<<";\n";
replacements["SAVE_DERIVATIVES"] = saveDerivs.str(); replacements["SAVE_DERIVATIVES"] = saveDerivs.str();
map<string, string> defines; map<string, string> defines;
if (useCutoff) if (useCutoff)
......
...@@ -32,6 +32,7 @@ __kernel void computeN2Energy( ...@@ -32,6 +32,7 @@ __kernel void computeN2Energy(
const unsigned int tgx = get_local_id(0) & (TILE_SIZE-1); const unsigned int tgx = get_local_id(0) & (TILE_SIZE-1);
const unsigned int tbx = get_local_id(0) - tgx; const unsigned int tbx = get_local_id(0) - tgx;
mixed energy = 0; mixed energy = 0;
INIT_PARAM_DERIVS
// First loop: process tiles that contain exclusions. // First loop: process tiles that contain exclusions.
...@@ -73,6 +74,7 @@ __kernel void computeN2Energy( ...@@ -73,6 +74,7 @@ __kernel void computeN2Energy(
atom2 = y*TILE_SIZE+j; atom2 = y*TILE_SIZE+j;
real dEdR = 0; real dEdR = 0;
real tempEnergy = 0; real tempEnergy = 0;
const real interactionScale = 0.5f;
#ifdef USE_EXCLUSIONS #ifdef USE_EXCLUSIONS
bool isExcluded = !(excl & 0x1); bool isExcluded = !(excl & 0x1);
#endif #endif
...@@ -123,6 +125,7 @@ __kernel void computeN2Energy( ...@@ -123,6 +125,7 @@ __kernel void computeN2Energy(
atom2 = y*TILE_SIZE+tj; atom2 = y*TILE_SIZE+tj;
real dEdR = 0; real dEdR = 0;
real tempEnergy = 0; real tempEnergy = 0;
const real interactionScale = 1.0f;
#ifdef USE_EXCLUSIONS #ifdef USE_EXCLUSIONS
bool isExcluded = !(excl & 0x1); bool isExcluded = !(excl & 0x1);
#endif #endif
...@@ -281,6 +284,7 @@ __kernel void computeN2Energy( ...@@ -281,6 +284,7 @@ __kernel void computeN2Energy(
atom2 = atomIndices[tbx+tj]; atom2 = atomIndices[tbx+tj];
real dEdR = 0; real dEdR = 0;
real tempEnergy = 0; real tempEnergy = 0;
const real interactionScale = 1.0f;
if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) { if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
COMPUTE_INTERACTION COMPUTE_INTERACTION
dEdR /= -r; dEdR /= -r;
...@@ -319,6 +323,7 @@ __kernel void computeN2Energy( ...@@ -319,6 +323,7 @@ __kernel void computeN2Energy(
atom2 = atomIndices[tbx+tj]; atom2 = atomIndices[tbx+tj];
real dEdR = 0; real dEdR = 0;
real tempEnergy = 0; real tempEnergy = 0;
const real interactionScale = 1.0f;
if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) { if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
COMPUTE_INTERACTION COMPUTE_INTERACTION
dEdR /= -r; dEdR /= -r;
...@@ -373,4 +378,5 @@ __kernel void computeN2Energy( ...@@ -373,4 +378,5 @@ __kernel void computeN2Energy(
pos++; pos++;
} }
energyBuffer[get_global_id(0)] += energy; energyBuffer[get_global_id(0)] += energy;
SAVE_PARAM_DERIVS
} }
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
__kernel void computePerParticleEnergy(int bufferSize, int numBuffers, __global real4* restrict forceBuffers, __global mixed* restrict energyBuffer, __global const real4* restrict posq __kernel void computePerParticleEnergy(int bufferSize, int numBuffers, __global real4* restrict forceBuffers, __global mixed* restrict energyBuffer, __global const real4* restrict posq
PARAMETER_ARGUMENTS) { PARAMETER_ARGUMENTS) {
mixed energy = 0; mixed energy = 0;
INIT_PARAM_DERIVS
unsigned int index = get_global_id(0); unsigned int index = get_global_id(0);
while (index < NUM_ATOMS) { while (index < NUM_ATOMS) {
// Reduce the derivatives // Reduce the derivatives
...@@ -27,4 +28,5 @@ __kernel void computePerParticleEnergy(int bufferSize, int numBuffers, __global ...@@ -27,4 +28,5 @@ __kernel void computePerParticleEnergy(int bufferSize, int numBuffers, __global
index += get_global_size(0); index += get_global_size(0);
} }
energyBuffer[get_global_id(0)] += energy; energyBuffer[get_global_id(0)] += energy;
SAVE_PARAM_DERIVS
} }
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
__kernel void computeGradientChainRuleTerms(__global real4* restrict forceBuffers, __global const real4* restrict posq __kernel void computeGradientChainRuleTerms(__global real4* restrict forceBuffers, __global const real4* restrict posq
PARAMETER_ARGUMENTS) { PARAMETER_ARGUMENTS) {
INIT_PARAM_DERIVS
unsigned int index = get_global_id(0); unsigned int index = get_global_id(0);
while (index < NUM_ATOMS) { while (index < NUM_ATOMS) {
real4 pos = posq[index]; real4 pos = posq[index];
...@@ -12,4 +13,5 @@ __kernel void computeGradientChainRuleTerms(__global real4* restrict forceBuffer ...@@ -12,4 +13,5 @@ __kernel void computeGradientChainRuleTerms(__global real4* restrict forceBuffer
forceBuffers[index] = force; forceBuffers[index] = force;
index += get_global_size(0); index += get_global_size(0);
} }
SAVE_PARAM_DERIVS
} }
...@@ -74,6 +74,7 @@ __kernel void computeN2Value(__global const real4* restrict posq, __local real4* ...@@ -74,6 +74,7 @@ __kernel void computeN2Value(__global const real4* restrict posq, __local real4*
COMPUTE_VALUE COMPUTE_VALUE
} }
value += tempValue1; value += tempValue1;
ADD_TEMP_DERIVS1
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
} }
#endif #endif
...@@ -123,6 +124,8 @@ __kernel void computeN2Value(__global const real4* restrict posq, __local real4* ...@@ -123,6 +124,8 @@ __kernel void computeN2Value(__global const real4* restrict posq, __local real4*
} }
value += tempValue1; value += tempValue1;
local_value[tbx+tj] += tempValue2; local_value[tbx+tj] += tempValue2;
ADD_TEMP_DERIVS1
ADD_TEMP_DERIVS2
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
} }
#endif #endif
...@@ -137,18 +140,23 @@ __kernel void computeN2Value(__global const real4* restrict posq, __local real4* ...@@ -137,18 +140,23 @@ __kernel void computeN2Value(__global const real4* restrict posq, __local real4*
// Write results. // Write results.
#ifdef SUPPORTS_64_BIT_ATOMICS #ifdef SUPPORTS_64_BIT_ATOMICS
unsigned int offset = x*TILE_SIZE + tgx; unsigned int offset1 = x*TILE_SIZE + tgx;
atom_add(&global_value[offset], (long) (value*0x100000000)); atom_add(&global_value[offset1], (long) (value*0x100000000));
STORE_PARAM_DERIVS1
if (x != y) { if (x != y) {
offset = y*TILE_SIZE + tgx; unsigned int offset2 = y*TILE_SIZE + tgx;
atom_add(&global_value[offset], (long) (local_value[get_local_id(0)]*0x100000000)); atom_add(&global_value[offset2], (long) (local_value[get_local_id(0)]*0x100000000));
STORE_PARAM_DERIVS2
} }
#else #else
unsigned int offset1 = x*TILE_SIZE + tgx + warp*PADDED_NUM_ATOMS; unsigned int offset1 = x*TILE_SIZE + tgx + warp*PADDED_NUM_ATOMS;
unsigned int offset2 = y*TILE_SIZE + tgx + warp*PADDED_NUM_ATOMS; unsigned int offset2 = y*TILE_SIZE + tgx + warp*PADDED_NUM_ATOMS;
global_value[offset1] += value; global_value[offset1] += value;
if (x != y) STORE_PARAM_DERIVS1
if (x != y) {
global_value[offset2] += local_value[get_local_id(0)]; global_value[offset2] += local_value[get_local_id(0)];
STORE_PARAM_DERIVS2
}
#endif #endif
} }
...@@ -292,6 +300,8 @@ __kernel void computeN2Value(__global const real4* restrict posq, __local real4* ...@@ -292,6 +300,8 @@ __kernel void computeN2Value(__global const real4* restrict posq, __local real4*
} }
value += tempValue1; value += tempValue1;
local_value[tbx+tj] += tempValue2; local_value[tbx+tj] += tempValue2;
ADD_TEMP_DERIVS1
ADD_TEMP_DERIVS2
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
} }
#endif #endif
...@@ -308,15 +318,23 @@ __kernel void computeN2Value(__global const real4* restrict posq, __local real4* ...@@ -308,15 +318,23 @@ __kernel void computeN2Value(__global const real4* restrict posq, __local real4*
unsigned int atom2 = y*TILE_SIZE + tgx; unsigned int atom2 = y*TILE_SIZE + tgx;
#endif #endif
#ifdef SUPPORTS_64_BIT_ATOMICS #ifdef SUPPORTS_64_BIT_ATOMICS
atom_add(&global_value[atom1], (long) (value*0x100000000)); unsigned in offset1 = atom1;
if (atom2 < PADDED_NUM_ATOMS) atom_add(&global_value[offset1], (long) (value*0x100000000));
atom_add(&global_value[atom2], (long) (local_value[get_local_id(0)]*0x100000000)); STORE_PARAM_DERIVS1
if (atom2 < PADDED_NUM_ATOMS) {
unsigned int offset2 = atom2;
atom_add(&global_value[offset2], (long) (local_value[get_local_id(0)]*0x100000000));
STORE_PARAM_DERIVS2
}
#else #else
unsigned int offset1 = atom1 + warp*PADDED_NUM_ATOMS; unsigned int offset1 = atom1 + warp*PADDED_NUM_ATOMS;
unsigned int offset2 = atom2 + warp*PADDED_NUM_ATOMS;
global_value[offset1] += value; global_value[offset1] += value;
if (atom2 < PADDED_NUM_ATOMS) STORE_PARAM_DERIVS1
if (atom2 < PADDED_NUM_ATOMS) {
unsigned int offset2 = atom2 + warp*PADDED_NUM_ATOMS;
global_value[offset2] += local_value[get_local_id(0)]; global_value[offset2] += local_value[get_local_id(0)];
STORE_PARAM_DERIVS2
}
#endif #endif
} }
pos++; pos++;
......
...@@ -21,6 +21,7 @@ __kernel void computePerParticleValues(int bufferSize, int numBuffers, __global ...@@ -21,6 +21,7 @@ __kernel void computePerParticleValues(int bufferSize, int numBuffers, __global
for (int i = index+bufferSize; i < totalSize; i += bufferSize) for (int i = index+bufferSize; i < totalSize; i += bufferSize)
sum += valueBuffers[i]; sum += valueBuffers[i];
#endif #endif
REDUCE_PARAM0_DERIV
// Now calculate other values // Now calculate other values
......
...@@ -187,8 +187,8 @@ void testEnergyParameterDerivatives() { ...@@ -187,8 +187,8 @@ void testEnergyParameterDerivatives() {
CustomBondForce* bonds = new CustomBondForce("k*(r-r0)^2"); CustomBondForce* bonds = new CustomBondForce("k*(r-r0)^2");
bonds->addGlobalParameter("r0", 0.0); bonds->addGlobalParameter("r0", 0.0);
bonds->addGlobalParameter("k", 0.0); bonds->addGlobalParameter("k", 0.0);
bonds->addEnergyParameterDerivative("r0");
bonds->addEnergyParameterDerivative("k"); bonds->addEnergyParameterDerivative("k");
bonds->addEnergyParameterDerivative("r0");
vector<double> parameters; vector<double> parameters;
bonds->addBond(0, 1, parameters); bonds->addBond(0, 1, parameters);
bonds->addBond(1, 2, parameters); bonds->addBond(1, 2, parameters);
......
...@@ -505,10 +505,10 @@ void testEnergyParameterDerivatives() { ...@@ -505,10 +505,10 @@ void testEnergyParameterDerivatives() {
force->addComputedValue("b", "a+B", CustomGBForce::SingleParticle); force->addComputedValue("b", "a+B", CustomGBForce::SingleParticle);
force->addEnergyTerm("C*(a1+b1+a2+b2+r)^0.8", CustomGBForce::ParticlePair); force->addEnergyTerm("C*(a1+b1+a2+b2+r)^0.8", CustomGBForce::ParticlePair);
force->addEnergyTerm("(D-B)*b", CustomGBForce::SingleParticle); force->addEnergyTerm("(D-B)*b", CustomGBForce::SingleParticle);
for (int i = 0; i < numParameters; i++) { for (int i = 0; i < numParameters; i++)
force->addGlobalParameter(paramNames[i], paramValues[i]); force->addGlobalParameter(paramNames[i], paramValues[i]);
for (int i = numParameters-1; i >= 0; i--)
force->addEnergyParameterDerivative(paramNames[i]); force->addEnergyParameterDerivative(paramNames[i]);
}
force->setNonbondedMethod(CustomGBForce::CutoffPeriodic); force->setNonbondedMethod(CustomGBForce::CutoffPeriodic);
force->setCutoffDistance(1.0); force->setCutoffDistance(1.0);
vector<Vec3> positions; vector<Vec3> positions;
......
...@@ -1050,8 +1050,8 @@ void testEnergyParameterDerivatives() { ...@@ -1050,8 +1050,8 @@ void testEnergyParameterDerivatives() {
CustomNonbondedForce* nonbonded = new CustomNonbondedForce("k*(r-r0)^2"); CustomNonbondedForce* nonbonded = new CustomNonbondedForce("k*(r-r0)^2");
nonbonded->addGlobalParameter("r0", 0.0); nonbonded->addGlobalParameter("r0", 0.0);
nonbonded->addGlobalParameter("k", 0.0); nonbonded->addGlobalParameter("k", 0.0);
nonbonded->addEnergyParameterDerivative("r0");
nonbonded->addEnergyParameterDerivative("k"); nonbonded->addEnergyParameterDerivative("k");
nonbonded->addEnergyParameterDerivative("r0");
vector<double> parameters; vector<double> parameters;
nonbonded->addParticle(parameters); nonbonded->addParticle(parameters);
nonbonded->addParticle(parameters); nonbonded->addParticle(parameters);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment