Commit 72d59cbe authored by Peter Eastman's avatar Peter Eastman
Browse files

Finished OpenCL implementation of CustomNonbondedForce. Also implemented a few optimizations.

parent 2127b8dd
......@@ -84,6 +84,8 @@ public:
ExpressionTreeNode(const ExpressionTreeNode& node);
ExpressionTreeNode();
~ExpressionTreeNode();
bool operator==(const ExpressionTreeNode& node) const;
bool operator!=(const ExpressionTreeNode& node) const;
ExpressionTreeNode& operator=(const ExpressionTreeNode& node);
/**
* Get the Operation performed by this node.
......
......@@ -95,6 +95,12 @@ public:
* @param variable the variable with respect to which the derivative should be taken
*/
virtual ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const = 0;
/**
 * Default inequality test: two operations are considered different exactly when
 * their ids differ. Subclasses that carry extra state (a constant value, a name, ...)
 * override this to compare that state as well.
 */
virtual bool operator!=(const Operation& op) const {
    const bool sameId = (getId() == op.getId());
    return !sameId;
}
/**
 * Equality is defined as the negation of operator!=, so overriding operator!=
 * in a subclass is enough to keep both comparisons consistent.
 */
virtual bool operator==(const Operation& op) const {
    if (*this != op)
        return false;
    return true;
}
class Constant;
class Variable;
class Custom;
......@@ -149,6 +155,10 @@ public:
double getValue() const {
return value;
}
/**
 * A Constant differs from op when op is not a Constant at all, or when the
 * two stored values are not equal.
 */
bool operator!=(const Operation& op) const {
    const Constant* other = dynamic_cast<const Constant*>(&op);
    if (other == NULL)
        return true;
    return other->value != value;
}
private:
double value;
};
......@@ -176,6 +186,10 @@ public:
return iter->second;
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
/**
 * A Variable differs from op when op is not a Variable at all, or when the
 * two variable names are not equal.
 */
bool operator!=(const Operation& op) const {
    const Variable* other = dynamic_cast<const Variable*>(&op);
    if (other == NULL)
        return true;
    return other->name != name;
}
private:
std::string name;
};
......@@ -214,6 +228,10 @@ public:
const std::vector<int>& getDerivOrder() const {
return derivOrder;
}
/**
 * A Custom operation differs from op when op is not a Custom operation, or when
 * any of the name, the isDerivative flag, or the derivative order vector differs.
 */
bool operator!=(const Operation& op) const {
    const Custom* other = dynamic_cast<const Custom*>(&op);
    if (other == NULL)
        return true;
    if (other->name != name)
        return true;
    if (other->isDerivative != isDerivative)
        return true;
    return other->derivOrder != derivOrder;
}
private:
std::string name;
CustomFunction* function;
......@@ -708,6 +726,10 @@ public:
double getValue() const {
return value;
}
/**
 * An AddConstant differs from op when op is not an AddConstant, or when the
 * two stored constants are not equal.
 */
bool operator!=(const Operation& op) const {
    const AddConstant* other = dynamic_cast<const AddConstant*>(&op);
    if (other == NULL)
        return true;
    return other->value != value;
}
private:
double value;
};
......@@ -737,6 +759,10 @@ public:
double getValue() const {
return value;
}
/**
 * A MultiplyConstant differs from op when op is not a MultiplyConstant, or when
 * the two stored constants are not equal.
 */
bool operator!=(const Operation& op) const {
    const MultiplyConstant* other = dynamic_cast<const MultiplyConstant*>(&op);
    if (other == NULL)
        return true;
    return other->value != value;
}
private:
double value;
};
......@@ -766,6 +792,10 @@ public:
double getValue() const {
return value;
}
/**
 * A PowerConstant differs from op when op is not a PowerConstant, or when the
 * two stored exponents are not equal.
 */
bool operator!=(const Operation& op) const {
    const PowerConstant* other = dynamic_cast<const PowerConstant*>(&op);
    if (other == NULL)
        return true;
    return other->value != value;
}
private:
double value;
};
......
......@@ -48,6 +48,11 @@ class ExpressionProgram;
class LEPTON_EXPORT ParsedExpression {
public:
/**
* Create an uninitialized ParsedExpression. This exists so that ParsedExpressions can be put in STL containers.
* Doing anything with it will produce an exception.
*/
ParsedExpression();
/**
* Create a ParsedExpression. Normally you will not call this directly. Instead, use the Parser class
* to parse an expression.
......
......@@ -70,6 +70,19 @@ ExpressionTreeNode::~ExpressionTreeNode() {
delete operation;
}
// Structural inequality: two nodes differ when their operations differ (as decided by
// Operation::operator!=), when they have a different number of children, or when any
// pair of corresponding children differs (checked recursively).
bool ExpressionTreeNode::operator!=(const ExpressionTreeNode& node) const {
    if (node.getOperation() != getOperation())
        return true;
    // Guard the element-wise loop below: without this check, a node with fewer children
    // than this one would be indexed out of bounds, and one with extra children would be
    // reported as equal.
    if (getChildren().size() != node.getChildren().size())
        return true;
    for (int i = 0; i < (int) getChildren().size(); i++)
        if (getChildren()[i] != node.getChildren()[i])
            return true;
    return false;
}
// Equality is simply the negation of operator!=, so the two can never disagree.
bool ExpressionTreeNode::operator==(const ExpressionTreeNode& node) const {
    const bool differs = (*this != node);
    return !differs;
}
ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node) {
if (operation != NULL)
delete operation;
......
......@@ -38,10 +38,15 @@
using namespace Lepton;
using namespace std;
// Creates an uninitialized ParsedExpression whose root is a default-constructed
// ExpressionTreeNode. getRootNode() checks for this state and throws.
ParsedExpression::ParsedExpression() : rootNode(ExpressionTreeNode()) {
}
// Creates a ParsedExpression that stores its own copy of the given root node.
ParsedExpression::ParsedExpression(const ExpressionTreeNode& rootNode) : rootNode(rootNode) {
}
// Returns the root node of the expression tree.
// Throws Exception if this ParsedExpression has no operation at its root, i.e. it was
// created with the default constructor and never assigned a real expression.
const ExpressionTreeNode& ParsedExpression::getRootNode() const {
    // NOTE(review): this takes the address of the reference returned by getOperation() and
    // compares it with NULL. If getOperation() dereferences a null pointer to produce that
    // reference, the test relies on undefined behavior and may be optimized away — TODO
    // confirm against ExpressionTreeNode and prefer an explicit "has operation" query.
    if (&rootNode.getOperation() == NULL)
        throw Exception("Illegal call to an uninitialized ParsedExpression");
    return rootNode;
}
......
......@@ -27,7 +27,6 @@
#include "OpenCLExpressionUtilities.h"
#include "openmm/OpenMMException.h"
#include "lepton/Operation.h"
#include <sstream>
using namespace OpenMM;
using namespace Lepton;
......@@ -46,69 +45,152 @@ static string intToString(int value) {
return s.str();
}
string OpenCLExpressionUtilities::createExpression(const ParsedExpression& expression, const map<string, string>& variables) {
return processExpression(expression.getRootNode(), variables);
// Generates OpenCL source for a set of expressions.
//
// For each map entry (visited in key order) the code needed to evaluate the expression is
// emitted through processExpression(), followed by one statement made of the map key, the
// name of the temporary holding the expression's value, and ";\n". The key is written
// verbatim, so it is expected to be an assignment prefix such as "tempEnergy += ".
// The list of temporaries is shared by all expressions in the call, so a node that was
// already emitted for one expression is not emitted again for another.
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
        const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams) {
    stringstream source;
    vector<pair<ExpressionTreeNode, string> > temps;
    map<string, ParsedExpression>::const_iterator entry = expressions.begin();
    while (entry != expressions.end()) {
        const ExpressionTreeNode& root = entry->second.getRootNode();
        processExpression(source, root, temps, variables, functions, prefix, functionParams);
        source << entry->first << getTempName(root, temps) << ";\n";
        ++entry;
    }
    return source.str();
}
string OpenCLExpressionUtilities::processExpression(const ExpressionTreeNode& node, const map<string, string>& variables) {
// Writes to `out` the OpenCL statements that evaluate `node`, and records in `temps` the
// name of the float temporary that holds the result.
//   out            - stream receiving the generated source
//   node           - expression tree node to generate code for
//   temps          - (node, temporary name) pairs emitted so far; appended to on success
//   variables      - maps a variable name in the expression to the text emitted for it
//   functions      - (function name, table array name) pairs for tabulated functions
//   prefix         - prepended to the running index to form each temporary's name
//   functionParams - name of the array indexed to obtain a tabulated function's float4 params
// Throws OpenMMException for an unknown variable, an unknown function, or an operation id
// not handled by the switch.
void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const map<string, string>& variables, const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams) {
// Nothing to do if an equal node already has a temporary.
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
return;
// Emit all children first so getTempName() can find their temporaries below.
for (int i = 0; i < (int) node.getChildren().size(); i++)
processExpression(out, node.getChildren()[i], temps, variables, functions, prefix, functionParams);
// The new temporary is named after the number of temporaries recorded so far.
string name = prefix+intToString(temps.size());
out << "float " << name << " = ";
switch (node.getOperation().getId()) {
case Operation::CONSTANT:
return doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue());
out << doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue());
break;
case Operation::VARIABLE:
{
map<string, string>::const_iterator iter = variables.find(node.getOperation().getName());
if (iter == variables.end())
throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
return iter->second;
out << iter->second;
break;
}
case Operation::CUSTOM:
{
// Tabulated function: find its index in `functions` by name.
int i;
for (i = 0; i < (int) functions.size() && functions[i].first != node.getOperation().getName(); i++)
;
if (i == functions.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
// The temporary is initialized to 0.0f and only overwritten, inside a scoped block, when
// the argument lies within [params.x, params.y].
out << "0.0f;\n";
out << "{\n";
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "float x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= params.x && x <= params.y) {\n";
out << "int index = (int) (floor((x-params.x)*params.z));\n";
out << "float4 coeff = " << functions[i].second << "[index];\n";
out << "x = (x-params.x)*params.z-index;\n";
// Derivative order 0 emits the cubic polynomial itself; otherwise its first derivative
// scaled by params.z is emitted.
if (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 0)
out << name << " = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));\n";
else
out << name << " = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;\n";
out << "}\n";
out << "}";
break;
}
case Operation::ADD:
return "("+processExpression(node.getChildren()[0], variables)+")+("+processExpression(node.getChildren()[1], variables)+")";
out << getTempName(node.getChildren()[0], temps) << "+" << getTempName(node.getChildren()[1], temps);
break;
case Operation::SUBTRACT:
return "("+processExpression(node.getChildren()[0], variables)+")-("+processExpression(node.getChildren()[1], variables)+")";
out << getTempName(node.getChildren()[0], temps) << "-" << getTempName(node.getChildren()[1], temps);
break;
case Operation::MULTIPLY:
return "("+processExpression(node.getChildren()[0], variables)+")*("+processExpression(node.getChildren()[1], variables)+")";
out << getTempName(node.getChildren()[0], temps) << "*" << getTempName(node.getChildren()[1], temps);
break;
case Operation::DIVIDE:
return "("+processExpression(node.getChildren()[0], variables)+")/("+processExpression(node.getChildren()[1], variables)+")";
out << getTempName(node.getChildren()[0], temps) << "/" << getTempName(node.getChildren()[1], temps);
break;
case Operation::POWER:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), ("+processExpression(node.getChildren()[1], variables)+"))";
out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << getTempName(node.getChildren()[1], temps) << ")";
break;
case Operation::NEGATE:
return "-("+processExpression(node.getChildren()[0], variables)+")";
out << "-" << getTempName(node.getChildren()[0], temps);
break;
case Operation::SQRT:
return "sqrt("+processExpression(node.getChildren()[0], variables)+")";
out << "sqrt(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::EXP:
return "exp("+processExpression(node.getChildren()[0], variables)+")";
out << "exp(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::LOG:
return "log("+processExpression(node.getChildren()[0], variables)+")";
out << "log(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SIN:
return "sin("+processExpression(node.getChildren()[0], variables)+")";
out << "sin(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::COS:
return "cos("+processExpression(node.getChildren()[0], variables)+")";
out << "cos(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SEC:
return "1.0f/cos("+processExpression(node.getChildren()[0], variables)+")";
out << "1.0f/cos(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::CSC:
return "1.0f/sin("+processExpression(node.getChildren()[0], variables)+")";
out << "1.0f/sin(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::TAN:
return "tan("+processExpression(node.getChildren()[0], variables)+")";
out << "tan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::COT:
return "1.0f/tan("+processExpression(node.getChildren()[0], variables)+")";
out << "1.0f/tan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ASIN:
return "asin("+processExpression(node.getChildren()[0], variables)+")";
out << "asin(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ACOS:
return "acos("+processExpression(node.getChildren()[0], variables)+")";
out << "acos(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ATAN:
return "atan("+processExpression(node.getChildren()[0], variables)+")";
out << "atan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SQUARE:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), 2.0f)";
{
// Emitted as an explicit product of the argument's temporary with itself.
string arg = getTempName(node.getChildren()[0], temps);
out << arg << "*" << arg;
break;
}
case Operation::CUBE:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), 3.0f)";
{
// Emitted as an explicit triple product of the argument's temporary.
string arg = getTempName(node.getChildren()[0], temps);
out << arg << "*" << arg << "*" << arg;
break;
}
case Operation::RECIPROCAL:
return "1.0f/("+processExpression(node.getChildren()[0], variables)+")";
out << "1.0f/" << getTempName(node.getChildren()[0], temps);
break;
case Operation::ADD_CONSTANT:
return doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue())+"+("+processExpression(node.getChildren()[0], variables)+")";
out << doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue()) << "+" << getTempName(node.getChildren()[0], temps);
break;
case Operation::MULTIPLY_CONSTANT:
return doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue())+"*("+processExpression(node.getChildren()[0], variables)+")";
out << doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()) << "*" << getTempName(node.getChildren()[0], temps);
break;
case Operation::POWER_CONSTANT:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), "+doubleToString(dynamic_cast<const Operation::PowerConstant*>(&node.getOperation())->getValue())+")";
out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << doubleToString(dynamic_cast<const Operation::PowerConstant*>(&node.getOperation())->getValue()) << ")";
break;
default:
throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
}
throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
// Terminate the statement and remember which temporary holds this node's value.
out << ";\n";
temps.push_back(make_pair(node, name));
}
// Looks up the temporary variable previously recorded for `node`.
// Returns the name paired with the first entry of `temps` whose node compares equal to
// `node`; throws OpenMMException (with the node streamed into the message) if none does.
string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, const vector<pair<ExpressionTreeNode, string> >& temps) {
    typedef vector<pair<ExpressionTreeNode, string> >::const_iterator TempIter;
    for (TempIter entry = temps.begin(); entry != temps.end(); ++entry) {
        if (entry->first == node)
            return entry->second;
    }
    stringstream message;
    message << "Internal error: No temporary variable for expression node: " << node;
    throw OpenMMException(message.str());
}
......@@ -30,7 +30,9 @@
#include "lepton/ExpressionTreeNode.h"
#include "lepton/ParsedExpression.h"
#include <map>
#include <sstream>
#include <string>
#include <utility>
namespace OpenMM {
......@@ -41,9 +43,13 @@ namespace OpenMM {
// Static helpers that translate Lepton expressions into OpenCL source text.
class OpenCLExpressionUtilities {
public:
// Returns source text for a single expression; `variables` maps each variable name used in
// the expression to the text to emit for it.
static std::string createExpression(const Lepton::ParsedExpression& expression, const std::map<std::string, std::string>& variables);
// Returns source text for a set of expressions. Each map key is emitted verbatim in front of
// the temporary holding that expression's value. `functions` pairs each tabulated function
// name with the name of its table array, `prefix` is used to name temporaries, and
// `functionParams` names the array holding the tabulated functions' parameters.
static std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams);
private:
// Recursive worker returning the source text for one node.
static std::string processExpression(const Lepton::ExpressionTreeNode& node, const std::map<std::string, std::string>& variables);
// Recursive worker writing the statements for one node to `out` and recording the
// temporary that holds its value in `temps`.
static void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams);
// Returns the name of the temporary recorded in `temps` for `node`.
static std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
};
} // namespace OpenMM
......
......@@ -33,6 +33,7 @@
#include "OpenCLExpressionUtilities.h"
#include "OpenCLIntegrationUtilities.h"
#include "OpenCLNonbondedUtilities.h"
#include "lepton/CustomFunction.h"
#include "lepton/Parser.h"
#include "lepton/ParsedExpression.h"
#include <cmath>
......@@ -231,13 +232,16 @@ void OpenCLCalcHarmonicBondForceKernel::initialize(const System& system, const H
}
// Binds the bond kernel's seven arguments (padded atom count, bond count, then the force,
// energy, posq, params and indices device buffers) and runs the kernel, passing numBonds to
// executeKernel. The `context` argument is not used here.
void OpenCLCalcHarmonicBondForceKernel::executeForces(ContextImpl& context) {
kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
kernel.setArg<cl_int>(1, numBonds);
kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer());
kernel.setArg<cl::Buffer>(3, cl.getEnergyBuffer().getDeviceBuffer());
kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer());
kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer());
kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer());
// Guarded argument setup: runs only while hasInitializedKernel is false, and sets it to true.
if (!hasInitializedKernel) {
hasInitializedKernel = true;
kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
kernel.setArg<cl_int>(1, numBonds);
kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer());
kernel.setArg<cl::Buffer>(3, cl.getEnergyBuffer().getDeviceBuffer());
kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer());
kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer());
kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer());
}
cl.executeKernel(kernel, numBonds);
}
......@@ -307,13 +311,16 @@ void OpenCLCalcHarmonicAngleForceKernel::initialize(const System& system, const
}
// Binds the angle kernel's seven arguments (padded atom count, angle count, then the force,
// energy, posq, params and indices device buffers) and runs the kernel, passing numAngles to
// executeKernel. The `context` argument is not used here.
void OpenCLCalcHarmonicAngleForceKernel::executeForces(ContextImpl& context) {
kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
kernel.setArg<cl_int>(1, numAngles);
kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer());
kernel.setArg<cl::Buffer>(3, cl.getEnergyBuffer().getDeviceBuffer());
kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer());
kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer());
kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer());
// Guarded argument setup: runs only while hasInitializedKernel is false, and sets it to true.
if (!hasInitializedKernel) {
hasInitializedKernel = true;
kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
kernel.setArg<cl_int>(1, numAngles);
kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer());
kernel.setArg<cl::Buffer>(3, cl.getEnergyBuffer().getDeviceBuffer());
kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer());
kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer());
kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer());
}
cl.executeKernel(kernel, numAngles);
}
......@@ -384,13 +391,16 @@ void OpenCLCalcPeriodicTorsionForceKernel::initialize(const System& system, cons
}
// Binds the periodic torsion kernel's seven arguments (padded atom count, torsion count, then
// the force, energy, posq, params and indices device buffers) and runs the kernel, passing
// numTorsions to executeKernel. The `context` argument is not used here.
void OpenCLCalcPeriodicTorsionForceKernel::executeForces(ContextImpl& context) {
kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
kernel.setArg<cl_int>(1, numTorsions);
kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer());
kernel.setArg<cl::Buffer>(3, cl.getEnergyBuffer().getDeviceBuffer());
kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer());
kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer());
kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer());
// Guarded argument setup: runs only while hasInitializedKernel is false, and sets it to true.
if (!hasInitializedKernel) {
hasInitializedKernel = true;
kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
kernel.setArg<cl_int>(1, numTorsions);
kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer());
kernel.setArg<cl::Buffer>(3, cl.getEnergyBuffer().getDeviceBuffer());
kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer());
kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer());
kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer());
}
cl.executeKernel(kernel, numTorsions);
}
......@@ -461,13 +471,16 @@ void OpenCLCalcRBTorsionForceKernel::initialize(const System& system, const RBTo
}
// Binds the RB torsion kernel's seven arguments (padded atom count, torsion count, then the
// force, energy, posq, params and indices device buffers) and runs the kernel, passing
// numTorsions to executeKernel. The `context` argument is not used here.
void OpenCLCalcRBTorsionForceKernel::executeForces(ContextImpl& context) {
kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
kernel.setArg<cl_int>(1, numTorsions);
kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer());
kernel.setArg<cl::Buffer>(3, cl.getEnergyBuffer().getDeviceBuffer());
kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer());
kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer());
kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer());
// Guarded argument setup: runs only while hasInitializedKernel is false, and sets it to true.
if (!hasInitializedKernel) {
hasInitializedKernel = true;
kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
kernel.setArg<cl_int>(1, numTorsions);
kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer());
kernel.setArg<cl::Buffer>(3, cl.getEnergyBuffer().getDeviceBuffer());
kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer());
kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer());
kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer());
}
cl.executeKernel(kernel, numTorsions);
}
......@@ -639,27 +652,33 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb
}
void OpenCLCalcNonbondedForceKernel::executeForces(ContextImpl& context) {
if (exceptionIndices != NULL) {
int numExceptions = exceptionIndices->getSize();
exceptionsKernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
exceptionsKernel.setArg<cl_int>(1, numExceptions);
exceptionsKernel.setArg<cl_float>(2, cutoffSquared);
exceptionsKernel.setArg<mm_float4>(3, cl.getNonbondedUtilities().getPeriodicBoxSize());
exceptionsKernel.setArg<cl::Buffer>(4, cl.getForceBuffers().getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(5, cl.getEnergyBuffer().getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(6, cl.getPosq().getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(7, exceptionParams->getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(8, exceptionIndices->getDeviceBuffer());
cl.executeKernel(exceptionsKernel, numExceptions);
if (!hasInitializedKernel) {
hasInitializedKernel = true;
if (exceptionIndices != NULL) {
int numExceptions = exceptionIndices->getSize();
exceptionsKernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
exceptionsKernel.setArg<cl_int>(1, numExceptions);
exceptionsKernel.setArg<cl_float>(2, cutoffSquared);
exceptionsKernel.setArg<mm_float4>(3, cl.getNonbondedUtilities().getPeriodicBoxSize());
exceptionsKernel.setArg<cl::Buffer>(4, cl.getForceBuffers().getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(5, cl.getEnergyBuffer().getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(6, cl.getPosq().getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(7, exceptionParams->getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(8, exceptionIndices->getDeviceBuffer());
}
if (cosSinSums != NULL) {
ewaldSumsKernel.setArg<cl::Buffer>(0, cl.getEnergyBuffer().getDeviceBuffer());
ewaldSumsKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
ewaldSumsKernel.setArg<cl::Buffer>(2, cosSinSums->getDeviceBuffer());
ewaldForcesKernel.setArg<cl::Buffer>(0, cl.getForceBuffers().getDeviceBuffer());
ewaldForcesKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
ewaldForcesKernel.setArg<cl::Buffer>(2, cosSinSums->getDeviceBuffer());
}
}
if (exceptionIndices != NULL)
cl.executeKernel(exceptionsKernel, exceptionIndices->getSize());
if (cosSinSums != NULL) {
ewaldSumsKernel.setArg<cl::Buffer>(0, cl.getEnergyBuffer().getDeviceBuffer());
ewaldSumsKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
ewaldSumsKernel.setArg<cl::Buffer>(2, cosSinSums->getDeviceBuffer());
cl.executeKernel(ewaldSumsKernel, cosSinSums->getSize());
ewaldForcesKernel.setArg<cl::Buffer>(0, cl.getForceBuffers().getDeviceBuffer());
ewaldForcesKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
ewaldForcesKernel.setArg<cl::Buffer>(2, cosSinSums->getDeviceBuffer());
cl.executeKernel(ewaldForcesKernel, cl.getNumAtoms());
}
}
......@@ -718,6 +737,10 @@ OpenCLCalcCustomNonbondedForceKernel::~OpenCLCalcCustomNonbondedForceKernel() {
delete exceptionParams;
if (exceptionIndices != NULL)
delete exceptionIndices;
if (tabulatedFunctionParams != NULL)
delete tabulatedFunctionParams;
for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
delete tabulatedFunctions[i];
}
void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, const CustomNonbondedForce& force) {
......@@ -746,9 +769,12 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
// Record parameters and exclusions.
int numParticles = force.getNumParticles();
string extraArguments;
params = new OpenCLArray<mm_float4>(cl, numParticles, "customNonbondedParameters");
if (force.getNumGlobalParameters() > 0)
if (force.getNumGlobalParameters() > 0) {
globals = new OpenCLArray<cl_float>(cl, force.getNumGlobalParameters(), "customNonbondedGlobals", false, CL_MEM_READ_ONLY);
extraArguments += ", __constant float* globals";
}
vector<mm_float4> paramVec(numParticles);
vector<vector<int> > exclusionList(numParticles);
for (int i = 0; i < numParticles; i++) {
......@@ -764,21 +790,80 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
paramVec[i].w = (cl_float) parameters[3];
exclusionList[i].push_back(i);
}
for (int i = 0; i < (int)exclusions.size(); i++) {
for (int i = 0; i < (int) exclusions.size(); i++) {
exclusionList[exclusions[i].first].push_back(exclusions[i].second);
exclusionList[exclusions[i].second].push_back(exclusions[i].first);
}
params->upload(paramVec);
// This class serves as a placeholder for custom functions in expressions.
class FunctionPlaceholder : public Lepton::CustomFunction {
public:
    // Every placeholder reports exactly one argument.
    int getNumArguments() const {
        return 1;
    }
    // The argument values are ignored; the result is always zero.
    double evaluate(const double*) const {
        return 0.0;
    }
    // The argument values and derivative order are ignored; the result is always zero.
    double evaluateDerivative(const double*, const int*) const {
        return 0.0;
    }
    // Returns a freshly allocated placeholder.
    CustomFunction* clone() const {
        return new FunctionPlaceholder;
    }
};
// Record the tabulated functions.
FunctionPlaceholder* fp = new FunctionPlaceholder();
map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions;
vector<mm_float4> tabulatedFunctionParamsVec(force.getNumFunctions());
for (int i = 0; i < force.getNumFunctions(); i++) {
string name;
vector<double> values;
double min, max;
bool interpolating;
force.getFunctionParameters(i, name, values, min, max, interpolating);
// gpuSetTabulatedFunction(gpu, i, name, values, min, max, interpolating);
string arrayName = prefix+"table"+intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = fp;
tabulatedFunctionParamsVec[i] = (mm_float4) {(float) min, (float) max, (float) ((values.size()-1)/(max-min)), 0.0f};
// First create a padded set of function values.
vector<double> padded(values.size()+2);
padded[0] = 2*values[0]-values[1];
for (int i = 0; i < (int) values.size(); i++)
padded[i+1] = values[i];
padded[padded.size()-1] = 2*values[values.size()-1]-values[values.size()-2];
// Now compute the spline coefficients.
vector<mm_float4> f(values.size()-1);
for (int i = 0; i < (int) values.size()-1; i++) {
if (interpolating)
f[i] = (mm_float4) {(cl_float) padded[i+1],
(cl_float) (0.5*(-padded[i]+padded[i+2])),
(cl_float) (0.5*(2.0*padded[i]-5.0*padded[i+1]+4.0*padded[i+2]-padded[i+3])),
(cl_float) (0.5*(-padded[i]+3.0*padded[i+1]-3.0*padded[i+2]+padded[i+3]))};
else
f[i] = (mm_float4) {(cl_float) ((padded[i]+4.0*padded[i+1]+padded[i+2])/6.0),
(cl_float) ((-3.0*padded[i]+3.0*padded[i+2])/6.0),
(cl_float) ((3.0*padded[i]-6.0*padded[i+1]+3.0*padded[i+2])/6.0),
(cl_float) ((-padded[i]+3.0*padded[i+1]-3.0*padded[i+2]+padded[i+3])/6.0)};
}
tabulatedFunctions.push_back(new OpenCLArray<mm_float4>(cl, values.size()-1, "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float4", sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
extraArguments += ", __constant float4* "+arrayName;
}
if (force.getNumFunctions() > 0) {
tabulatedFunctionParams = new OpenCLArray<mm_float4>(cl, tabulatedFunctionParamsVec.size(), "tabulatedFunctionParameters", false, CL_MEM_READ_ONLY);
tabulatedFunctionParams->upload(tabulatedFunctionParamsVec);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(prefix+"functionParams", "float4", sizeof(cl_float4), tabulatedFunctionParams->getDeviceBuffer()));
extraArguments += ", __constant float4* "+prefix+"functionParams";
}
// Record information for the expressions.
......@@ -799,8 +884,11 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
globals->upload(globalParamValues);
bool useCutoff = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff);
bool usePeriodic = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff && force.getNonbondedMethod() != CustomNonbondedForce::CutoffNonPeriodic);
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize();
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
Lepton::ParsedExpression forceExpression = energyExpression.differentiate("r").optimize();
map<string, Lepton::ParsedExpression> forceExpressions;
forceExpressions["tempEnergy += "] = energyExpression;
forceExpressions["tempForce -= "] = forceExpression;
// Create the kernels.
......@@ -824,13 +912,13 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
forceVariables[name] = prefix+value;
exceptionVariables[name] = value;
}
stringstream compute;
map<string, Lepton::ParsedExpression> paramExpressions;
for (int i = 0; i < force.getNumParameters(); i++) {
Lepton::ParsedExpression expression = Lepton::Parser::parse(force.getParameterCombiningRule(i)).optimize();
compute << "float " << prefix << force.getParameterName(i) << " = " << OpenCLExpressionUtilities::createExpression(expression, paramVariables) << ";\n";
paramExpressions["float "+prefix+force.getParameterName(i)+" = " ] = Lepton::Parser::parse(force.getParameterCombiningRule(i)).optimize();
}
compute << "tempEnergy += " << OpenCLExpressionUtilities::createExpression(energyExpression, forceVariables) << ";\n";
compute << "tempForce -= " << OpenCLExpressionUtilities::createExpression(forceExpression, forceVariables) << ";\n";
stringstream compute;
compute << OpenCLExpressionUtilities::createExpressions(paramExpressions, paramVariables, functionDefinitions, prefix+"param_temp", prefix+"functionParams");
compute << OpenCLExpressionUtilities::createExpressions(forceExpressions, forceVariables, functionDefinitions, prefix+"force_temp", prefix+"functionParams");
map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str();
string source = cl.loadSourceFromFile("customNonbonded.cl", replacements);
......@@ -840,13 +928,20 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
globals->upload(globalParamValues);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(prefix+"globals", "float", sizeof(cl_float), globals->getDeviceBuffer()));
}
map<string, Lepton::ParsedExpression> exceptionExpressions;
stringstream computeExceptions;
computeExceptions << "energy += " << OpenCLExpressionUtilities::createExpression(energyExpression, exceptionVariables) << ";\n";
computeExceptions << "dEdR = " << OpenCLExpressionUtilities::createExpression(forceExpression, exceptionVariables) << ";\n";
exceptionExpressions["energy += "] = energyExpression;
exceptionExpressions["dEdR = "] = forceExpression;
computeExceptions << OpenCLExpressionUtilities::createExpressions(exceptionExpressions, exceptionVariables, functionDefinitions, "temp", prefix+"functionParams");
replacements["COMPUTE_FORCE"] = computeExceptions.str();
replacements["EXTRA_ARGUMENTS"] = extraArguments;
map<string, string> defines;
if (globals != NULL)
defines["HAS_GLOBALS"] = "1";
defines["CUTOFF_SQUARED"] = doubleToString(force.getCutoffDistance()*force.getCutoffDistance());
Vec3 boxVectors[3];
system.getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
defines["PERIODIC_BOX_SIZE_X"] = doubleToString(boxVectors[0][0]);
defines["PERIODIC_BOX_SIZE_Y"] = doubleToString(boxVectors[1][1]);
defines["PERIODIC_BOX_SIZE_Z"] = doubleToString(boxVectors[2][2]);
cl::Program program = cl.createProgram(cl.loadSourceFromFile("customNonbondedExceptions.cl", replacements), defines);
exceptionsKernel = cl::Kernel(program, "computeCustomNonbondedExceptions");
......@@ -880,23 +975,22 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
maxBuffers = max(maxBuffers, forceBufferCounter[i]);
}
cl.addForce(new OpenCLCustomNonbondedForceInfo(maxBuffers, force));
delete fp;
}
void OpenCLCalcCustomNonbondedForceKernel::executeForces(ContextImpl& context) {
if (exceptionParams != NULL) {
if (!hasCreatedKernels) {
hasCreatedKernels = true;
if (!hasInitializedKernel) {
hasInitializedKernel = true;
exceptionsKernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
exceptionsKernel.setArg<cl_int>(1, exceptionParams->getSize());
exceptionsKernel.setArg<cl_float>(2, cl.getNonbondedUtilities().getCutoffDistance()*cl.getNonbondedUtilities().getCutoffDistance());
exceptionsKernel.setArg<mm_float4>(3, cl.getNonbondedUtilities().getPeriodicBoxSize());
exceptionsKernel.setArg<cl::Buffer>(4, cl.getForceBuffers().getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(5, cl.getEnergyBuffer().getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(6, cl.getPosq().getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(7, exceptionParams->getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(8, exceptionIndices->getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(3, cl.getEnergyBuffer().getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(5, exceptionParams->getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(6, exceptionIndices->getDeviceBuffer());
if (globals != NULL)
exceptionsKernel.setArg<cl::Buffer>(9, globals->getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(7, globals->getDeviceBuffer());
}
cl.executeKernel(exceptionsKernel, exceptionIndices->getSize());
}
......
......@@ -150,8 +150,8 @@ private:
*/
class OpenCLCalcHarmonicBondForceKernel : public CalcHarmonicBondForceKernel {
public:
OpenCLCalcHarmonicBondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) :
CalcHarmonicBondForceKernel(name, platform), cl(cl), system(system), params(NULL), indices(NULL) {
OpenCLCalcHarmonicBondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcHarmonicBondForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), system(system), params(NULL), indices(NULL) {
}
~OpenCLCalcHarmonicBondForceKernel();
/**
......@@ -176,6 +176,7 @@ public:
double executeEnergy(ContextImpl& context);
private:
int numBonds;
bool hasInitializedKernel;
OpenCLContext& cl;
System& system;
OpenCLArray<mm_float2>* params;
......@@ -188,7 +189,8 @@ private:
*/
class OpenCLCalcHarmonicAngleForceKernel : public CalcHarmonicAngleForceKernel {
public:
OpenCLCalcHarmonicAngleForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcHarmonicAngleForceKernel(name, platform), cl(cl), system(system) {
OpenCLCalcHarmonicAngleForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcHarmonicAngleForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), system(system) {
}
~OpenCLCalcHarmonicAngleForceKernel();
/**
......@@ -213,6 +215,7 @@ public:
double executeEnergy(ContextImpl& context);
private:
int numAngles;
bool hasInitializedKernel;
OpenCLContext& cl;
System& system;
OpenCLArray<mm_float2>* params;
......@@ -225,7 +228,8 @@ private:
*/
class OpenCLCalcPeriodicTorsionForceKernel : public CalcPeriodicTorsionForceKernel {
public:
OpenCLCalcPeriodicTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcPeriodicTorsionForceKernel(name, platform), cl(cl), system(system) {
OpenCLCalcPeriodicTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcPeriodicTorsionForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), system(system) {
}
~OpenCLCalcPeriodicTorsionForceKernel();
/**
......@@ -250,6 +254,7 @@ public:
double executeEnergy(ContextImpl& context);
private:
int numTorsions;
bool hasInitializedKernel;
OpenCLContext& cl;
System& system;
OpenCLArray<mm_float4>* params;
......@@ -262,7 +267,8 @@ private:
*/
class OpenCLCalcRBTorsionForceKernel : public CalcRBTorsionForceKernel {
public:
OpenCLCalcRBTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcRBTorsionForceKernel(name, platform), cl(cl), system(system) {
OpenCLCalcRBTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcRBTorsionForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), system(system) {
}
~OpenCLCalcRBTorsionForceKernel();
/**
......@@ -287,6 +293,7 @@ public:
double executeEnergy(ContextImpl& context);
private:
int numTorsions;
bool hasInitializedKernel;
OpenCLContext& cl;
System& system;
OpenCLArray<mm_float8>* params;
......@@ -299,8 +306,8 @@ private:
*/
class OpenCLCalcNonbondedForceKernel : public CalcNonbondedForceKernel {
public:
OpenCLCalcNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcNonbondedForceKernel(name, platform), cl(cl),
sigmaEpsilon(NULL), exceptionParams(NULL), exceptionIndices(NULL), cosSinSums(NULL) {
OpenCLCalcNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcNonbondedForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), sigmaEpsilon(NULL), exceptionParams(NULL), exceptionIndices(NULL), cosSinSums(NULL) {
}
~OpenCLCalcNonbondedForceKernel();
/**
......@@ -325,6 +332,7 @@ public:
double executeEnergy(ContextImpl& context);
private:
OpenCLContext& cl;
bool hasInitializedKernel;
OpenCLArray<mm_float2>* sigmaEpsilon;
OpenCLArray<mm_float4>* exceptionParams;
OpenCLArray<mm_int4>* exceptionIndices;
......@@ -341,7 +349,7 @@ private:
class OpenCLCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel {
public:
OpenCLCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcCustomNonbondedForceKernel(name, platform),
hasCreatedKernels(false), cl(cl), params(NULL), globals(NULL), exceptionParams(NULL), exceptionIndices(NULL), system(system) {
hasInitializedKernel(false), cl(cl), params(NULL), globals(NULL), exceptionParams(NULL), exceptionIndices(NULL), tabulatedFunctionParams(NULL), system(system) {
}
~OpenCLCalcCustomNonbondedForceKernel();
/**
......@@ -365,15 +373,17 @@ public:
*/
double executeEnergy(ContextImpl& context);
private:
bool hasCreatedKernels;
bool hasInitializedKernel;
OpenCLContext& cl;
OpenCLArray<mm_float4>* params;
OpenCLArray<cl_float>* globals;
OpenCLArray<mm_float4>* exceptionParams;
OpenCLArray<mm_int4>* exceptionIndices;
OpenCLArray<mm_float4>* tabulatedFunctionParams;
cl::Kernel exceptionsKernel;
std::vector<std::string> globalParamNames;
std::vector<cl_float> globalParamValues;
std::vector<OpenCLArray<mm_float4>*> tabulatedFunctions;
System& system;
};
......
......@@ -2,13 +2,9 @@
* Compute custom nonbonded exceptions.
*/
__kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions, float cutoffSquared, float4 periodicBoxSize, __global float4* forceBuffers, __global float* energyBuffer,
__kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions, __global float4* forceBuffers, __global float* energyBuffer,
__global float4* posq, __global float4* params, __global int4* indices
#ifdef HAS_GLOBALS
, __constant float* globals) {
#else
) {
#endif
EXTRA_ARGUMENTS) {
int index = get_global_id(0);
float energy = 0.0f;
while (index < numExceptions) {
......@@ -18,15 +14,15 @@ __kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions,
float4 exceptionParams = params[index];
float4 delta = posq[atoms.y]-posq[atoms.x];
#ifdef USE_PERIODIC
delta.x -= floor(delta.x/periodicBoxSize.x+0.5f)*periodicBoxSize.x;
delta.y -= floor(delta.y/periodicBoxSize.y+0.5f)*periodicBoxSize.y;
delta.z -= floor(delta.z/periodicBoxSize.z+0.5f)*periodicBoxSize.z;
delta.x -= floor(delta.x/PERIODIC_BOX_SIZE_X+0.5f)*PERIODIC_BOX_SIZE_X;
delta.y -= floor(delta.y/PERIODIC_BOX_SIZE_Y+0.5f)*PERIODIC_BOX_SIZE_Y;
delta.z -= floor(delta.z/PERIODIC_BOX_SIZE_Z+0.5f)*PERIODIC_BOX_SIZE_Z;
#endif
// Compute the force.
float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
if (r2 > cutoffSquared) {
if (r2 > CUTOFF_SQUARED) {
#else
{
#endif
......
......@@ -241,8 +241,8 @@ int main() {
testExceptions();
testCutoff();
testPeriodic();
// testTabulatedFunction(true);
// testTabulatedFunction(false);
testTabulatedFunction(true);
testTabulatedFunction(false);
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment