Commit e82ace44 authored by peastman's avatar peastman
Browse files

OpenCL implementation of tabulated functions for CustomIntegrator

parent 094e6c71
......@@ -1472,7 +1472,9 @@ private:
class ReorderListener;
class GlobalTarget;
class DerivFunction;
std::string createPerDofComputation(const std::string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const std::string& forceName, const std::string& energyName);
std::string createPerDofComputation(const std::string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator,
const std::string& forceName, const std::string& energyName, std::vector<const TabulatedFunction*>& functions,
std::vector<std::pair<std::string, std::string> >& functionNames);
void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid);
Lepton::ExpressionTreeNode replaceDerivFunctions(const Lepton::ExpressionTreeNode& node, OpenMM::ContextImpl& context);
void findExpressionsForDerivs(const Lepton::ExpressionTreeNode& node, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variableNodes);
......@@ -1491,6 +1493,7 @@ private:
OpenCLArray* uniformRandoms;
OpenCLArray* randomSeed;
OpenCLArray* perDofEnergyParamDerivs;
std::vector<OpenCLArray*> tabulatedFunctions;
std::map<int, OpenCLArray*> savedForces;
std::set<int> validSavedForces;
OpenCLParameterSet* perDofValues;
......
......@@ -48,6 +48,7 @@
#include "lepton/Operation.h"
#include "lepton/Parser.h"
#include "lepton/ParsedExpression.h"
#include "ReferenceTabulatedFunction.h"
#include "SimTKOpenMMRealType.h"
#include "SimTKOpenMMUtilities.h"
#include <algorithm>
......@@ -7408,6 +7409,8 @@ OpenCLIntegrateCustomStepKernel::~OpenCLIntegrateCustomStepKernel() {
delete perDofEnergyParamDerivs;
if (perDofValues != NULL)
delete perDofValues;
for (auto function : tabulatedFunctions)
delete function;
for (auto& f : savedForces)
delete f.second;
}
......@@ -7424,7 +7427,8 @@ void OpenCLIntegrateCustomStepKernel::initialize(const System& system, const Cus
SimTKOpenMMUtilities::setRandomNumberSeed(integrator.getRandomNumberSeed());
}
string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const string& forceName, const string& energyName) {
string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator,
const string& forceName, const string& energyName, vector<const TabulatedFunction*>& functions, vector<pair<string, string> >& functionNames) {
const string suffixes[] = {".x", ".y", ".z"};
string suffix = suffixes[component];
map<string, Lepton::ParsedExpression> expressions;
......@@ -7457,8 +7461,6 @@ string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& va
variables[integrator.getPerDofVariableName(i)] = "perDof"+suffix.substr(1)+perDofValues->getParameterSuffix(i);
for (int i = 0; i < (int) parameterNames.size(); i++)
variables[parameterNames[i]] = "globals["+cl.intToString(parameterVariableIndex[i])+"]";
vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
string tempType = (cl.getSupportsDoublePrecision() ? "double" : "float");
vector<pair<ExpressionTreeNode, string> > variableNodes;
findExpressionsForDerivs(expr.getRootNode(), variableNodes);
......@@ -7492,13 +7494,34 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
map<string, string> defines;
defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
defines["WORK_GROUP_SIZE"] = cl.intToString(OpenCLContext::ThreadBlockSize);
// Record the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionNames;
vector<const TabulatedFunction*> functionList;
vector<string> tableTypes;
for (int i = 0; i < integrator.getNumTabulatedFunctions(); i++) {
functionList.push_back(&integrator.getTabulatedFunction(i));
string name = integrator.getTabulatedFunctionName(i);
string arrayName = "table"+cl.intToString(i);
functionNames.push_back(make_pair(name, arrayName));
functions[name] = createReferenceTabulatedFunction(integrator.getTabulatedFunction(i));
int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(integrator.getTabulatedFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
if (width == 1)
tableTypes.push_back("float");
else
tableTypes.push_back("float"+cl.intToString(width));
}
// Record information about all the computation steps.
vector<string> variable(numSteps);
vector<int> forceGroup;
vector<vector<Lepton::ParsedExpression> > expression;
map<string, Lepton::CustomFunction*> functions;
CustomIntegratorUtilities::analyzeComputations(context, integrator, expression, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup, functions);
for (int step = 0; step < numSteps; step++) {
string expr;
......@@ -7670,7 +7693,7 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
if (numUniform > 0)
compute << "float4 uniform = uniformValues[uniformIndex+index];\n";
for (int i = 0; i < 3; i++)
compute << createPerDofComputation(stepType[j] == CustomIntegrator::ComputePerDof ? variable[j] : "", expression[j][0], i, integrator, forceName[j], energyName[j]);
compute << createPerDofComputation(stepType[j] == CustomIntegrator::ComputePerDof ? variable[j] : "", expression[j][0], i, integrator, forceName[j], energyName[j], functionList, functionNames);
if (variable[j] == "x") {
if (storePosAsDelta[j]) {
if (cl.getSupportsDoublePrecision())
......@@ -7705,6 +7728,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
string valueName = "perDofValues"+cl.intToString(i+1);
args << ", __global " << buffer.getType() << "* restrict " << valueName;
}
for (int i = 0; i < (int) tableTypes.size(); i++)
args << ", __global const " << tableTypes[i]<< "* restrict table" << i;
replacements["PARAMETER_ARGUMENTS"] = args.str();
if (loadPosAsDelta[step])
defines["LOAD_POS_AS_DELTA"] = "1";
......@@ -7728,6 +7753,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
kernel.setArg<cl::Buffer>(index++, perDofEnergyParamDerivs->getDeviceBuffer());
for (auto& buffer : perDofValues->getBuffers())
kernel.setArg<cl::Memory>(index++, buffer.getMemory());
for (auto* array : tabulatedFunctions)
kernel.setArg<cl::Buffer>(index++, array->getDeviceBuffer());
if (stepType[step] == CustomIntegrator::ComputeSum) {
// Create a second kernel for this step that sums the values.
......@@ -7790,7 +7817,7 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
}
Lepton::ParsedExpression keExpression = Lepton::Parser::parse(integrator.getKineticEnergyExpression()).optimize();
for (int i = 0; i < 3; i++)
computeKE << createPerDofComputation("", keExpression, i, integrator, "f", "");
computeKE << createPerDofComputation("", keExpression, i, integrator, "f", "", functionList, functionNames);
map<string, string> replacements;
replacements["COMPUTE_STEP"] = computeKE.str();
stringstream args;
......@@ -7799,6 +7826,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
string valueName = "perDofValues"+cl.intToString(i+1);
args << ", __global " << buffer.getType() << "* restrict " << valueName;
}
for (int i = 0; i < (int) tableTypes.size(); i++)
args << ", __global const " << tableTypes[i]<< "* restrict table" << i;
replacements["PARAMETER_ARGUMENTS"] = args.str();
if (defines.find("LOAD_POS_AS_DELTA") != defines.end())
defines.erase("LOAD_POS_AS_DELTA");
......@@ -7822,6 +7851,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
kineticEnergyKernel.setArg<cl::Buffer>(index++, perDofEnergyParamDerivs->getDeviceBuffer());
for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++)
kineticEnergyKernel.setArg<cl::Memory>(index++, perDofValues->getBuffers()[i].getMemory());
for (auto* array : tabulatedFunctions)
kineticEnergyKernel.setArg<cl::Buffer>(index++, array->getDeviceBuffer());
keNeedsForce = usesVariable(keExpression, "f");
// Create a second kernel to sum the values.
......@@ -7832,8 +7863,13 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
sumKineticEnergyKernel.setArg<cl::Buffer>(index++, sumBuffer->getDeviceBuffer());
sumKineticEnergyKernel.setArg<cl::Buffer>(index++, summedValue->getDeviceBuffer());
sumKineticEnergyKernel.setArg<cl_int>(index++, 3*numAtoms);
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
}
// Make sure all values (variables, parameters, etc.) are up to date.
if (!deviceValuesAreCurrent) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment