Commit 72d59cbe authored by Peter Eastman's avatar Peter Eastman
Browse files

Finished OpenCL implementation of CustomNonbondedForce. Also implemented a few optimizations.

parent 2127b8dd
...@@ -84,6 +84,8 @@ public: ...@@ -84,6 +84,8 @@ public:
ExpressionTreeNode(const ExpressionTreeNode& node); ExpressionTreeNode(const ExpressionTreeNode& node);
ExpressionTreeNode(); ExpressionTreeNode();
~ExpressionTreeNode(); ~ExpressionTreeNode();
bool operator==(const ExpressionTreeNode& node) const;
bool operator!=(const ExpressionTreeNode& node) const;
ExpressionTreeNode& operator=(const ExpressionTreeNode& node); ExpressionTreeNode& operator=(const ExpressionTreeNode& node);
/** /**
* Get the Operation performed by this node. * Get the Operation performed by this node.
......
...@@ -95,6 +95,12 @@ public: ...@@ -95,6 +95,12 @@ public:
* @param variable the variable with respect to which the derivate should be taken * @param variable the variable with respect to which the derivate should be taken
*/ */
virtual ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const = 0; virtual ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const = 0;
virtual bool operator!=(const Operation& op) const {
return op.getId() != getId();
}
virtual bool operator==(const Operation& op) const {
return !(*this != op);
}
class Constant; class Constant;
class Variable; class Variable;
class Custom; class Custom;
...@@ -149,6 +155,10 @@ public: ...@@ -149,6 +155,10 @@ public:
double getValue() const { double getValue() const {
return value; return value;
} }
bool operator!=(const Operation& op) const {
const Constant* o = dynamic_cast<const Constant*>(&op);
return (o == NULL || o->value != value);
}
private: private:
double value; double value;
}; };
...@@ -176,6 +186,10 @@ public: ...@@ -176,6 +186,10 @@ public:
return iter->second; return iter->second;
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const; ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
bool operator!=(const Operation& op) const {
const Variable* o = dynamic_cast<const Variable*>(&op);
return (o == NULL || o->name != name);
}
private: private:
std::string name; std::string name;
}; };
...@@ -214,6 +228,10 @@ public: ...@@ -214,6 +228,10 @@ public:
const std::vector<int>& getDerivOrder() const { const std::vector<int>& getDerivOrder() const {
return derivOrder; return derivOrder;
} }
bool operator!=(const Operation& op) const {
const Custom* o = dynamic_cast<const Custom*>(&op);
return (o == NULL || o->name != name || o->isDerivative != isDerivative || o->derivOrder != derivOrder);
}
private: private:
std::string name; std::string name;
CustomFunction* function; CustomFunction* function;
...@@ -708,6 +726,10 @@ public: ...@@ -708,6 +726,10 @@ public:
double getValue() const { double getValue() const {
return value; return value;
} }
bool operator!=(const Operation& op) const {
const AddConstant* o = dynamic_cast<const AddConstant*>(&op);
return (o == NULL || o->value != value);
}
private: private:
double value; double value;
}; };
...@@ -737,6 +759,10 @@ public: ...@@ -737,6 +759,10 @@ public:
double getValue() const { double getValue() const {
return value; return value;
} }
bool operator!=(const Operation& op) const {
const MultiplyConstant* o = dynamic_cast<const MultiplyConstant*>(&op);
return (o == NULL || o->value != value);
}
private: private:
double value; double value;
}; };
...@@ -766,6 +792,10 @@ public: ...@@ -766,6 +792,10 @@ public:
double getValue() const { double getValue() const {
return value; return value;
} }
bool operator!=(const Operation& op) const {
const PowerConstant* o = dynamic_cast<const PowerConstant*>(&op);
return (o == NULL || o->value != value);
}
private: private:
double value; double value;
}; };
......
...@@ -48,6 +48,11 @@ class ExpressionProgram; ...@@ -48,6 +48,11 @@ class ExpressionProgram;
class LEPTON_EXPORT ParsedExpression { class LEPTON_EXPORT ParsedExpression {
public: public:
/**
* Create an uninitialized ParsedExpression. This exists so that ParsedExpressions can be put in STL containers.
* Doing anything with it will produce an exception.
*/
ParsedExpression();
/** /**
* Create a ParsedExpression. Normally you will not call this directly. Instead, use the Parser class * Create a ParsedExpression. Normally you will not call this directly. Instead, use the Parser class
* to parse expression. * to parse expression.
......
...@@ -70,6 +70,19 @@ ExpressionTreeNode::~ExpressionTreeNode() { ...@@ -70,6 +70,19 @@ ExpressionTreeNode::~ExpressionTreeNode() {
delete operation; delete operation;
} }
bool ExpressionTreeNode::operator!=(const ExpressionTreeNode& node) const {
if (node.getOperation() != getOperation())
return true;
for (int i = 0; i < (int) getChildren().size(); i++)
if (getChildren()[i] != node.getChildren()[i])
return true;
return false;
}
bool ExpressionTreeNode::operator==(const ExpressionTreeNode& node) const {
return !(*this != node);
}
ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node) { ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node) {
if (operation != NULL) if (operation != NULL)
delete operation; delete operation;
......
...@@ -38,10 +38,15 @@ ...@@ -38,10 +38,15 @@
using namespace Lepton; using namespace Lepton;
using namespace std; using namespace std;
ParsedExpression::ParsedExpression() : rootNode(ExpressionTreeNode()) {
}
ParsedExpression::ParsedExpression(const ExpressionTreeNode& rootNode) : rootNode(rootNode) { ParsedExpression::ParsedExpression(const ExpressionTreeNode& rootNode) : rootNode(rootNode) {
} }
const ExpressionTreeNode& ParsedExpression::getRootNode() const { const ExpressionTreeNode& ParsedExpression::getRootNode() const {
if (&rootNode.getOperation() == NULL)
throw Exception("Illegal call to an initialized ParsedExpression");
return rootNode; return rootNode;
} }
......
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include "OpenCLExpressionUtilities.h" #include "OpenCLExpressionUtilities.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "lepton/Operation.h" #include "lepton/Operation.h"
#include <sstream>
using namespace OpenMM; using namespace OpenMM;
using namespace Lepton; using namespace Lepton;
...@@ -46,69 +45,152 @@ static string intToString(int value) { ...@@ -46,69 +45,152 @@ static string intToString(int value) {
return s.str(); return s.str();
} }
string OpenCLExpressionUtilities::createExpression(const ParsedExpression& expression, const map<string, string>& variables) { string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
return processExpression(expression.getRootNode(), variables); const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams) {
stringstream out;
vector<pair<ExpressionTreeNode, string> > temps;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, variables, functions, prefix, functionParams);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
}
return out.str();
} }
string OpenCLExpressionUtilities::processExpression(const ExpressionTreeNode& node, const map<string, string>& variables) { void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const map<string, string>& variables, const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams) {
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
return;
for (int i = 0; i < (int) node.getChildren().size(); i++)
processExpression(out, node.getChildren()[i], temps, variables, functions, prefix, functionParams);
string name = prefix+intToString(temps.size());
out << "float " << name << " = ";
switch (node.getOperation().getId()) { switch (node.getOperation().getId()) {
case Operation::CONSTANT: case Operation::CONSTANT:
return doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue()); out << doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue());
break;
case Operation::VARIABLE: case Operation::VARIABLE:
{ {
map<string, string>::const_iterator iter = variables.find(node.getOperation().getName()); map<string, string>::const_iterator iter = variables.find(node.getOperation().getName());
if (iter == variables.end()) if (iter == variables.end())
throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName()); throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
return iter->second; out << iter->second;
break;
}
case Operation::CUSTOM:
{
int i;
for (i = 0; i < (int) functions.size() && functions[i].first != node.getOperation().getName(); i++)
;
if (i == functions.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
out << "0.0f;\n";
out << "{\n";
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "float x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= params.x && x <= params.y) {\n";
out << "int index = (int) (floor((x-params.x)*params.z));\n";
out << "float4 coeff = " << functions[i].second << "[index];\n";
out << "x = (x-params.x)*params.z-index;\n";
if (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 0)
out << name << " = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));\n";
else
out << name << " = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;\n";
out << "}\n";
out << "}";
break;
} }
case Operation::ADD: case Operation::ADD:
return "("+processExpression(node.getChildren()[0], variables)+")+("+processExpression(node.getChildren()[1], variables)+")"; out << getTempName(node.getChildren()[0], temps) << "+" << getTempName(node.getChildren()[1], temps);
break;
case Operation::SUBTRACT: case Operation::SUBTRACT:
return "("+processExpression(node.getChildren()[0], variables)+")-("+processExpression(node.getChildren()[1], variables)+")"; out << getTempName(node.getChildren()[0], temps) << "-" << getTempName(node.getChildren()[1], temps);
break;
case Operation::MULTIPLY: case Operation::MULTIPLY:
return "("+processExpression(node.getChildren()[0], variables)+")*("+processExpression(node.getChildren()[1], variables)+")"; out << getTempName(node.getChildren()[0], temps) << "*" << getTempName(node.getChildren()[1], temps);
break;
case Operation::DIVIDE: case Operation::DIVIDE:
return "("+processExpression(node.getChildren()[0], variables)+")/("+processExpression(node.getChildren()[1], variables)+")"; out << getTempName(node.getChildren()[0], temps) << "/" << getTempName(node.getChildren()[1], temps);
break;
case Operation::POWER: case Operation::POWER:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), ("+processExpression(node.getChildren()[1], variables)+"))"; out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << getTempName(node.getChildren()[1], temps) << ")";
break;
case Operation::NEGATE: case Operation::NEGATE:
return "-("+processExpression(node.getChildren()[0], variables)+")"; out << "-" << getTempName(node.getChildren()[0], temps);
break;
case Operation::SQRT: case Operation::SQRT:
return "sqrt("+processExpression(node.getChildren()[0], variables)+")"; out << "sqrt(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::EXP: case Operation::EXP:
return "exp("+processExpression(node.getChildren()[0], variables)+")"; out << "exp(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::LOG: case Operation::LOG:
return "log("+processExpression(node.getChildren()[0], variables)+")"; out << "log(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SIN: case Operation::SIN:
return "sin("+processExpression(node.getChildren()[0], variables)+")"; out << "sin(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::COS: case Operation::COS:
return "cos("+processExpression(node.getChildren()[0], variables)+")"; out << "cos(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SEC: case Operation::SEC:
return "1.0f/cos("+processExpression(node.getChildren()[0], variables)+")"; out << "1.0f/cos(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::CSC: case Operation::CSC:
return "1.0f/sin("+processExpression(node.getChildren()[0], variables)+")"; out << "1.0f/sin(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::TAN: case Operation::TAN:
return "tan("+processExpression(node.getChildren()[0], variables)+")"; out << "tan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::COT: case Operation::COT:
return "1.0f/tan("+processExpression(node.getChildren()[0], variables)+")"; out << "1.0f/tan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ASIN: case Operation::ASIN:
return "asin("+processExpression(node.getChildren()[0], variables)+")"; out << "asin(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ACOS: case Operation::ACOS:
return "acos("+processExpression(node.getChildren()[0], variables)+")"; out << "acos(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ATAN: case Operation::ATAN:
return "atan("+processExpression(node.getChildren()[0], variables)+")"; out << "atan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SQUARE: case Operation::SQUARE:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), 2.0f)"; {
string arg = getTempName(node.getChildren()[0], temps);
out << arg << "*" << arg;
break;
}
case Operation::CUBE: case Operation::CUBE:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), 3.0f)"; {
string arg = getTempName(node.getChildren()[0], temps);
out << arg << "*" << arg << "*" << arg;
break;
}
case Operation::RECIPROCAL: case Operation::RECIPROCAL:
return "1.0f/("+processExpression(node.getChildren()[0], variables)+")"; out << "1.0f/" << getTempName(node.getChildren()[0], temps);
break;
case Operation::ADD_CONSTANT: case Operation::ADD_CONSTANT:
return doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue())+"+("+processExpression(node.getChildren()[0], variables)+")"; out << doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue()) << "+" << getTempName(node.getChildren()[0], temps);
break;
case Operation::MULTIPLY_CONSTANT: case Operation::MULTIPLY_CONSTANT:
return doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue())+"*("+processExpression(node.getChildren()[0], variables)+")"; out << doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()) << "*" << getTempName(node.getChildren()[0], temps);
break;
case Operation::POWER_CONSTANT: case Operation::POWER_CONSTANT:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), "+doubleToString(dynamic_cast<const Operation::PowerConstant*>(&node.getOperation())->getValue())+")"; out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << doubleToString(dynamic_cast<const Operation::PowerConstant*>(&node.getOperation())->getValue()) << ")";
} break;
default:
throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName()); throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
}
out << ";\n";
temps.push_back(make_pair(node, name));
}
string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, const vector<pair<ExpressionTreeNode, string> >& temps) {
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
return temps[i].second;
stringstream out;
out << "Internal error: No temporary variable for expression node: " << node;
throw OpenMMException(out.str());
} }
...@@ -30,7 +30,9 @@ ...@@ -30,7 +30,9 @@
#include "lepton/ExpressionTreeNode.h" #include "lepton/ExpressionTreeNode.h"
#include "lepton/ParsedExpression.h" #include "lepton/ParsedExpression.h"
#include <map> #include <map>
#include <sstream>
#include <string> #include <string>
#include <utility>
namespace OpenMM { namespace OpenMM {
...@@ -41,9 +43,13 @@ namespace OpenMM { ...@@ -41,9 +43,13 @@ namespace OpenMM {
class OpenCLExpressionUtilities { class OpenCLExpressionUtilities {
public: public:
static std::string createExpression(const Lepton::ParsedExpression& expression, const std::map<std::string, std::string>& variables); static std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams);
private: private:
static std::string processExpression(const Lepton::ExpressionTreeNode& node, const std::map<std::string, std::string>& variables); static void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams);
static std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include "OpenCLExpressionUtilities.h" #include "OpenCLExpressionUtilities.h"
#include "OpenCLIntegrationUtilities.h" #include "OpenCLIntegrationUtilities.h"
#include "OpenCLNonbondedUtilities.h" #include "OpenCLNonbondedUtilities.h"
#include "lepton/CustomFunction.h"
#include "lepton/Parser.h" #include "lepton/Parser.h"
#include "lepton/ParsedExpression.h" #include "lepton/ParsedExpression.h"
#include <cmath> #include <cmath>
...@@ -231,6 +232,8 @@ void OpenCLCalcHarmonicBondForceKernel::initialize(const System& system, const H ...@@ -231,6 +232,8 @@ void OpenCLCalcHarmonicBondForceKernel::initialize(const System& system, const H
} }
void OpenCLCalcHarmonicBondForceKernel::executeForces(ContextImpl& context) { void OpenCLCalcHarmonicBondForceKernel::executeForces(ContextImpl& context) {
if (!hasInitializedKernel) {
hasInitializedKernel = true;
kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms()); kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
kernel.setArg<cl_int>(1, numBonds); kernel.setArg<cl_int>(1, numBonds);
kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer()); kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer());
...@@ -238,6 +241,7 @@ void OpenCLCalcHarmonicBondForceKernel::executeForces(ContextImpl& context) { ...@@ -238,6 +241,7 @@ void OpenCLCalcHarmonicBondForceKernel::executeForces(ContextImpl& context) {
kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer()); kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer());
kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer()); kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer());
kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer()); kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer());
}
cl.executeKernel(kernel, numBonds); cl.executeKernel(kernel, numBonds);
} }
...@@ -307,6 +311,8 @@ void OpenCLCalcHarmonicAngleForceKernel::initialize(const System& system, const ...@@ -307,6 +311,8 @@ void OpenCLCalcHarmonicAngleForceKernel::initialize(const System& system, const
} }
void OpenCLCalcHarmonicAngleForceKernel::executeForces(ContextImpl& context) { void OpenCLCalcHarmonicAngleForceKernel::executeForces(ContextImpl& context) {
if (!hasInitializedKernel) {
hasInitializedKernel = true;
kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms()); kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
kernel.setArg<cl_int>(1, numAngles); kernel.setArg<cl_int>(1, numAngles);
kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer()); kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer());
...@@ -314,6 +320,7 @@ void OpenCLCalcHarmonicAngleForceKernel::executeForces(ContextImpl& context) { ...@@ -314,6 +320,7 @@ void OpenCLCalcHarmonicAngleForceKernel::executeForces(ContextImpl& context) {
kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer()); kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer());
kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer()); kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer());
kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer()); kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer());
}
cl.executeKernel(kernel, numAngles); cl.executeKernel(kernel, numAngles);
} }
...@@ -384,6 +391,8 @@ void OpenCLCalcPeriodicTorsionForceKernel::initialize(const System& system, cons ...@@ -384,6 +391,8 @@ void OpenCLCalcPeriodicTorsionForceKernel::initialize(const System& system, cons
} }
void OpenCLCalcPeriodicTorsionForceKernel::executeForces(ContextImpl& context) { void OpenCLCalcPeriodicTorsionForceKernel::executeForces(ContextImpl& context) {
if (!hasInitializedKernel) {
hasInitializedKernel = true;
kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms()); kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
kernel.setArg<cl_int>(1, numTorsions); kernel.setArg<cl_int>(1, numTorsions);
kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer()); kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer());
...@@ -391,6 +400,7 @@ void OpenCLCalcPeriodicTorsionForceKernel::executeForces(ContextImpl& context) { ...@@ -391,6 +400,7 @@ void OpenCLCalcPeriodicTorsionForceKernel::executeForces(ContextImpl& context) {
kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer()); kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer());
kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer()); kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer());
kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer()); kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer());
}
cl.executeKernel(kernel, numTorsions); cl.executeKernel(kernel, numTorsions);
} }
...@@ -461,6 +471,8 @@ void OpenCLCalcRBTorsionForceKernel::initialize(const System& system, const RBTo ...@@ -461,6 +471,8 @@ void OpenCLCalcRBTorsionForceKernel::initialize(const System& system, const RBTo
} }
void OpenCLCalcRBTorsionForceKernel::executeForces(ContextImpl& context) { void OpenCLCalcRBTorsionForceKernel::executeForces(ContextImpl& context) {
if (!hasInitializedKernel) {
hasInitializedKernel = true;
kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms()); kernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
kernel.setArg<cl_int>(1, numTorsions); kernel.setArg<cl_int>(1, numTorsions);
kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer()); kernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer());
...@@ -468,6 +480,7 @@ void OpenCLCalcRBTorsionForceKernel::executeForces(ContextImpl& context) { ...@@ -468,6 +480,7 @@ void OpenCLCalcRBTorsionForceKernel::executeForces(ContextImpl& context) {
kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer()); kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer());
kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer()); kernel.setArg<cl::Buffer>(5, params->getDeviceBuffer());
kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer()); kernel.setArg<cl::Buffer>(6, indices->getDeviceBuffer());
}
cl.executeKernel(kernel, numTorsions); cl.executeKernel(kernel, numTorsions);
} }
...@@ -639,6 +652,8 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb ...@@ -639,6 +652,8 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb
} }
void OpenCLCalcNonbondedForceKernel::executeForces(ContextImpl& context) { void OpenCLCalcNonbondedForceKernel::executeForces(ContextImpl& context) {
if (!hasInitializedKernel) {
hasInitializedKernel = true;
if (exceptionIndices != NULL) { if (exceptionIndices != NULL) {
int numExceptions = exceptionIndices->getSize(); int numExceptions = exceptionIndices->getSize();
exceptionsKernel.setArg<cl_int>(0, cl.getPaddedNumAtoms()); exceptionsKernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
...@@ -650,16 +665,20 @@ void OpenCLCalcNonbondedForceKernel::executeForces(ContextImpl& context) { ...@@ -650,16 +665,20 @@ void OpenCLCalcNonbondedForceKernel::executeForces(ContextImpl& context) {
exceptionsKernel.setArg<cl::Buffer>(6, cl.getPosq().getDeviceBuffer()); exceptionsKernel.setArg<cl::Buffer>(6, cl.getPosq().getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(7, exceptionParams->getDeviceBuffer()); exceptionsKernel.setArg<cl::Buffer>(7, exceptionParams->getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(8, exceptionIndices->getDeviceBuffer()); exceptionsKernel.setArg<cl::Buffer>(8, exceptionIndices->getDeviceBuffer());
cl.executeKernel(exceptionsKernel, numExceptions);
} }
if (cosSinSums != NULL) { if (cosSinSums != NULL) {
ewaldSumsKernel.setArg<cl::Buffer>(0, cl.getEnergyBuffer().getDeviceBuffer()); ewaldSumsKernel.setArg<cl::Buffer>(0, cl.getEnergyBuffer().getDeviceBuffer());
ewaldSumsKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer()); ewaldSumsKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
ewaldSumsKernel.setArg<cl::Buffer>(2, cosSinSums->getDeviceBuffer()); ewaldSumsKernel.setArg<cl::Buffer>(2, cosSinSums->getDeviceBuffer());
cl.executeKernel(ewaldSumsKernel, cosSinSums->getSize());
ewaldForcesKernel.setArg<cl::Buffer>(0, cl.getForceBuffers().getDeviceBuffer()); ewaldForcesKernel.setArg<cl::Buffer>(0, cl.getForceBuffers().getDeviceBuffer());
ewaldForcesKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer()); ewaldForcesKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
ewaldForcesKernel.setArg<cl::Buffer>(2, cosSinSums->getDeviceBuffer()); ewaldForcesKernel.setArg<cl::Buffer>(2, cosSinSums->getDeviceBuffer());
}
}
if (exceptionIndices != NULL)
cl.executeKernel(exceptionsKernel, exceptionIndices->getSize());
if (cosSinSums != NULL) {
cl.executeKernel(ewaldSumsKernel, cosSinSums->getSize());
cl.executeKernel(ewaldForcesKernel, cl.getNumAtoms()); cl.executeKernel(ewaldForcesKernel, cl.getNumAtoms());
} }
} }
...@@ -718,6 +737,10 @@ OpenCLCalcCustomNonbondedForceKernel::~OpenCLCalcCustomNonbondedForceKernel() { ...@@ -718,6 +737,10 @@ OpenCLCalcCustomNonbondedForceKernel::~OpenCLCalcCustomNonbondedForceKernel() {
delete exceptionParams; delete exceptionParams;
if (exceptionIndices != NULL) if (exceptionIndices != NULL)
delete exceptionIndices; delete exceptionIndices;
if (tabulatedFunctionParams != NULL)
delete tabulatedFunctionParams;
for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
delete tabulatedFunctions[i];
} }
void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, const CustomNonbondedForce& force) { void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, const CustomNonbondedForce& force) {
...@@ -746,9 +769,12 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -746,9 +769,12 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
// Record parameters and exclusions. // Record parameters and exclusions.
int numParticles = force.getNumParticles(); int numParticles = force.getNumParticles();
string extraArguments;
params = new OpenCLArray<mm_float4>(cl, numParticles, "customNonbondedParameters"); params = new OpenCLArray<mm_float4>(cl, numParticles, "customNonbondedParameters");
if (force.getNumGlobalParameters() > 0) if (force.getNumGlobalParameters() > 0) {
globals = new OpenCLArray<cl_float>(cl, force.getNumGlobalParameters(), "customNonbondedGlobals", false, CL_MEM_READ_ONLY); globals = new OpenCLArray<cl_float>(cl, force.getNumGlobalParameters(), "customNonbondedGlobals", false, CL_MEM_READ_ONLY);
extraArguments += ", __constant float* globals";
}
vector<mm_float4> paramVec(numParticles); vector<mm_float4> paramVec(numParticles);
vector<vector<int> > exclusionList(numParticles); vector<vector<int> > exclusionList(numParticles);
for (int i = 0; i < numParticles; i++) { for (int i = 0; i < numParticles; i++) {
...@@ -764,21 +790,80 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -764,21 +790,80 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
paramVec[i].w = (cl_float) parameters[3]; paramVec[i].w = (cl_float) parameters[3];
exclusionList[i].push_back(i); exclusionList[i].push_back(i);
} }
for (int i = 0; i < (int)exclusions.size(); i++) { for (int i = 0; i < (int) exclusions.size(); i++) {
exclusionList[exclusions[i].first].push_back(exclusions[i].second); exclusionList[exclusions[i].first].push_back(exclusions[i].second);
exclusionList[exclusions[i].second].push_back(exclusions[i].first); exclusionList[exclusions[i].second].push_back(exclusions[i].first);
} }
params->upload(paramVec); params->upload(paramVec);
// This class serves as a placeholder for custom functions in expressions.
class FunctionPlaceholder : public Lepton::CustomFunction {
public:
int getNumArguments() const {
return 1;
}
double evaluate(const double* arguments) const {
return 0.0;
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* clone() const {
return new FunctionPlaceholder();
}
};
// Record the tabulated functions. // Record the tabulated functions.
FunctionPlaceholder* fp = new FunctionPlaceholder();
map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions;
vector<mm_float4> tabulatedFunctionParamsVec(force.getNumFunctions());
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
string name; string name;
vector<double> values; vector<double> values;
double min, max; double min, max;
bool interpolating; bool interpolating;
force.getFunctionParameters(i, name, values, min, max, interpolating); force.getFunctionParameters(i, name, values, min, max, interpolating);
// gpuSetTabulatedFunction(gpu, i, name, values, min, max, interpolating); string arrayName = prefix+"table"+intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = fp;
tabulatedFunctionParamsVec[i] = (mm_float4) {(float) min, (float) max, (float) ((values.size()-1)/(max-min)), 0.0f};
// First create a padded set of function values.
vector<double> padded(values.size()+2);
padded[0] = 2*values[0]-values[1];
for (int i = 0; i < (int) values.size(); i++)
padded[i+1] = values[i];
padded[padded.size()-1] = 2*values[values.size()-1]-values[values.size()-2];
// Now compute the spline coefficients.
vector<mm_float4> f(values.size()-1);
for (int i = 0; i < (int) values.size()-1; i++) {
if (interpolating)
f[i] = (mm_float4) {(cl_float) padded[i+1],
(cl_float) (0.5*(-padded[i]+padded[i+2])),
(cl_float) (0.5*(2.0*padded[i]-5.0*padded[i+1]+4.0*padded[i+2]-padded[i+3])),
(cl_float) (0.5*(-padded[i]+3.0*padded[i+1]-3.0*padded[i+2]+padded[i+3]))};
else
f[i] = (mm_float4) {(cl_float) ((padded[i]+4.0*padded[i+1]+padded[i+2])/6.0),
(cl_float) ((-3.0*padded[i]+3.0*padded[i+2])/6.0),
(cl_float) ((3.0*padded[i]-6.0*padded[i+1]+3.0*padded[i+2])/6.0),
(cl_float) ((-padded[i]+3.0*padded[i+1]-3.0*padded[i+2]+padded[i+3])/6.0)};
}
tabulatedFunctions.push_back(new OpenCLArray<mm_float4>(cl, values.size()-1, "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float4", sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
extraArguments += ", __constant float4* "+arrayName;
}
if (force.getNumFunctions() > 0) {
tabulatedFunctionParams = new OpenCLArray<mm_float4>(cl, tabulatedFunctionParamsVec.size(), "tabulatedFunctionParameters", false, CL_MEM_READ_ONLY);
tabulatedFunctionParams->upload(tabulatedFunctionParamsVec);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(prefix+"functionParams", "float4", sizeof(cl_float4), tabulatedFunctionParams->getDeviceBuffer()));
extraArguments += ", __constant float4* "+prefix+"functionParams";
} }
// Record information for the expressions. // Record information for the expressions.
...@@ -799,8 +884,11 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -799,8 +884,11 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
globals->upload(globalParamValues); globals->upload(globalParamValues);
bool useCutoff = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff); bool useCutoff = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff);
bool usePeriodic = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff && force.getNonbondedMethod() != CustomNonbondedForce::CutoffNonPeriodic); bool usePeriodic = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff && force.getNonbondedMethod() != CustomNonbondedForce::CutoffNonPeriodic);
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize(); Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
Lepton::ParsedExpression forceExpression = energyExpression.differentiate("r").optimize(); Lepton::ParsedExpression forceExpression = energyExpression.differentiate("r").optimize();
map<string, Lepton::ParsedExpression> forceExpressions;
forceExpressions["tempEnergy += "] = energyExpression;
forceExpressions["tempForce -= "] = forceExpression;
// Create the kernels. // Create the kernels.
...@@ -824,13 +912,13 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -824,13 +912,13 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
forceVariables[name] = prefix+value; forceVariables[name] = prefix+value;
exceptionVariables[name] = value; exceptionVariables[name] = value;
} }
stringstream compute; map<string, Lepton::ParsedExpression> paramExpressions;
for (int i = 0; i < force.getNumParameters(); i++) { for (int i = 0; i < force.getNumParameters(); i++) {
Lepton::ParsedExpression expression = Lepton::Parser::parse(force.getParameterCombiningRule(i)).optimize(); paramExpressions["float "+prefix+force.getParameterName(i)+" = " ] = Lepton::Parser::parse(force.getParameterCombiningRule(i)).optimize();
compute << "float " << prefix << force.getParameterName(i) << " = " << OpenCLExpressionUtilities::createExpression(expression, paramVariables) << ";\n";
} }
compute << "tempEnergy += " << OpenCLExpressionUtilities::createExpression(energyExpression, forceVariables) << ";\n"; stringstream compute;
compute << "tempForce -= " << OpenCLExpressionUtilities::createExpression(forceExpression, forceVariables) << ";\n"; compute << OpenCLExpressionUtilities::createExpressions(paramExpressions, paramVariables, functionDefinitions, prefix+"param_temp", prefix+"functionParams");
compute << OpenCLExpressionUtilities::createExpressions(forceExpressions, forceVariables, functionDefinitions, prefix+"force_temp", prefix+"functionParams");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
string source = cl.loadSourceFromFile("customNonbonded.cl", replacements); string source = cl.loadSourceFromFile("customNonbonded.cl", replacements);
...@@ -840,13 +928,20 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -840,13 +928,20 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
globals->upload(globalParamValues); globals->upload(globalParamValues);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(prefix+"globals", "float", sizeof(cl_float), globals->getDeviceBuffer())); cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(prefix+"globals", "float", sizeof(cl_float), globals->getDeviceBuffer()));
} }
map<string, Lepton::ParsedExpression> exceptionExpressions;
stringstream computeExceptions; stringstream computeExceptions;
computeExceptions << "energy += " << OpenCLExpressionUtilities::createExpression(energyExpression, exceptionVariables) << ";\n"; exceptionExpressions["energy += "] = energyExpression;
computeExceptions << "dEdR = " << OpenCLExpressionUtilities::createExpression(forceExpression, exceptionVariables) << ";\n"; exceptionExpressions["dEdR = "] = forceExpression;
computeExceptions << OpenCLExpressionUtilities::createExpressions(exceptionExpressions, exceptionVariables, functionDefinitions, "temp", prefix+"functionParams");
replacements["COMPUTE_FORCE"] = computeExceptions.str(); replacements["COMPUTE_FORCE"] = computeExceptions.str();
replacements["EXTRA_ARGUMENTS"] = extraArguments;
map<string, string> defines; map<string, string> defines;
if (globals != NULL) defines["CUTOFF_SQUARED"] = doubleToString(force.getCutoffDistance()*force.getCutoffDistance());
defines["HAS_GLOBALS"] = "1"; Vec3 boxVectors[3];
system.getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
defines["PERIODIC_BOX_SIZE_X"] = doubleToString(boxVectors[0][0]);
defines["PERIODIC_BOX_SIZE_Y"] = doubleToString(boxVectors[1][1]);
defines["PERIODIC_BOX_SIZE_Z"] = doubleToString(boxVectors[2][2]);
cl::Program program = cl.createProgram(cl.loadSourceFromFile("customNonbondedExceptions.cl", replacements), defines); cl::Program program = cl.createProgram(cl.loadSourceFromFile("customNonbondedExceptions.cl", replacements), defines);
exceptionsKernel = cl::Kernel(program, "computeCustomNonbondedExceptions"); exceptionsKernel = cl::Kernel(program, "computeCustomNonbondedExceptions");
...@@ -880,23 +975,22 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -880,23 +975,22 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
maxBuffers = max(maxBuffers, forceBufferCounter[i]); maxBuffers = max(maxBuffers, forceBufferCounter[i]);
} }
cl.addForce(new OpenCLCustomNonbondedForceInfo(maxBuffers, force)); cl.addForce(new OpenCLCustomNonbondedForceInfo(maxBuffers, force));
delete fp;
} }
void OpenCLCalcCustomNonbondedForceKernel::executeForces(ContextImpl& context) { void OpenCLCalcCustomNonbondedForceKernel::executeForces(ContextImpl& context) {
if (exceptionParams != NULL) { if (exceptionParams != NULL) {
if (!hasCreatedKernels) { if (!hasInitializedKernel) {
hasCreatedKernels = true; hasInitializedKernel = true;
exceptionsKernel.setArg<cl_int>(0, cl.getPaddedNumAtoms()); exceptionsKernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
exceptionsKernel.setArg<cl_int>(1, exceptionParams->getSize()); exceptionsKernel.setArg<cl_int>(1, exceptionParams->getSize());
exceptionsKernel.setArg<cl_float>(2, cl.getNonbondedUtilities().getCutoffDistance()*cl.getNonbondedUtilities().getCutoffDistance()); exceptionsKernel.setArg<cl::Buffer>(2, cl.getForceBuffers().getDeviceBuffer());
exceptionsKernel.setArg<mm_float4>(3, cl.getNonbondedUtilities().getPeriodicBoxSize()); exceptionsKernel.setArg<cl::Buffer>(3, cl.getEnergyBuffer().getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(4, cl.getForceBuffers().getDeviceBuffer()); exceptionsKernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(5, cl.getEnergyBuffer().getDeviceBuffer()); exceptionsKernel.setArg<cl::Buffer>(5, exceptionParams->getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(6, cl.getPosq().getDeviceBuffer()); exceptionsKernel.setArg<cl::Buffer>(6, exceptionIndices->getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(7, exceptionParams->getDeviceBuffer());
exceptionsKernel.setArg<cl::Buffer>(8, exceptionIndices->getDeviceBuffer());
if (globals != NULL) if (globals != NULL)
exceptionsKernel.setArg<cl::Buffer>(9, globals->getDeviceBuffer()); exceptionsKernel.setArg<cl::Buffer>(7, globals->getDeviceBuffer());
} }
cl.executeKernel(exceptionsKernel, exceptionIndices->getSize()); cl.executeKernel(exceptionsKernel, exceptionIndices->getSize());
} }
......
...@@ -150,8 +150,8 @@ private: ...@@ -150,8 +150,8 @@ private:
*/ */
class OpenCLCalcHarmonicBondForceKernel : public CalcHarmonicBondForceKernel { class OpenCLCalcHarmonicBondForceKernel : public CalcHarmonicBondForceKernel {
public: public:
OpenCLCalcHarmonicBondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : OpenCLCalcHarmonicBondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcHarmonicBondForceKernel(name, platform),
CalcHarmonicBondForceKernel(name, platform), cl(cl), system(system), params(NULL), indices(NULL) { hasInitializedKernel(false), cl(cl), system(system), params(NULL), indices(NULL) {
} }
~OpenCLCalcHarmonicBondForceKernel(); ~OpenCLCalcHarmonicBondForceKernel();
/** /**
...@@ -176,6 +176,7 @@ public: ...@@ -176,6 +176,7 @@ public:
double executeEnergy(ContextImpl& context); double executeEnergy(ContextImpl& context);
private: private:
int numBonds; int numBonds;
bool hasInitializedKernel;
OpenCLContext& cl; OpenCLContext& cl;
System& system; System& system;
OpenCLArray<mm_float2>* params; OpenCLArray<mm_float2>* params;
...@@ -188,7 +189,8 @@ private: ...@@ -188,7 +189,8 @@ private:
*/ */
class OpenCLCalcHarmonicAngleForceKernel : public CalcHarmonicAngleForceKernel { class OpenCLCalcHarmonicAngleForceKernel : public CalcHarmonicAngleForceKernel {
public: public:
OpenCLCalcHarmonicAngleForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcHarmonicAngleForceKernel(name, platform), cl(cl), system(system) { OpenCLCalcHarmonicAngleForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcHarmonicAngleForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), system(system) {
} }
~OpenCLCalcHarmonicAngleForceKernel(); ~OpenCLCalcHarmonicAngleForceKernel();
/** /**
...@@ -213,6 +215,7 @@ public: ...@@ -213,6 +215,7 @@ public:
double executeEnergy(ContextImpl& context); double executeEnergy(ContextImpl& context);
private: private:
int numAngles; int numAngles;
bool hasInitializedKernel;
OpenCLContext& cl; OpenCLContext& cl;
System& system; System& system;
OpenCLArray<mm_float2>* params; OpenCLArray<mm_float2>* params;
...@@ -225,7 +228,8 @@ private: ...@@ -225,7 +228,8 @@ private:
*/ */
class OpenCLCalcPeriodicTorsionForceKernel : public CalcPeriodicTorsionForceKernel { class OpenCLCalcPeriodicTorsionForceKernel : public CalcPeriodicTorsionForceKernel {
public: public:
OpenCLCalcPeriodicTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcPeriodicTorsionForceKernel(name, platform), cl(cl), system(system) { OpenCLCalcPeriodicTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcPeriodicTorsionForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), system(system) {
} }
~OpenCLCalcPeriodicTorsionForceKernel(); ~OpenCLCalcPeriodicTorsionForceKernel();
/** /**
...@@ -250,6 +254,7 @@ public: ...@@ -250,6 +254,7 @@ public:
double executeEnergy(ContextImpl& context); double executeEnergy(ContextImpl& context);
private: private:
int numTorsions; int numTorsions;
bool hasInitializedKernel;
OpenCLContext& cl; OpenCLContext& cl;
System& system; System& system;
OpenCLArray<mm_float4>* params; OpenCLArray<mm_float4>* params;
...@@ -262,7 +267,8 @@ private: ...@@ -262,7 +267,8 @@ private:
*/ */
class OpenCLCalcRBTorsionForceKernel : public CalcRBTorsionForceKernel { class OpenCLCalcRBTorsionForceKernel : public CalcRBTorsionForceKernel {
public: public:
OpenCLCalcRBTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcRBTorsionForceKernel(name, platform), cl(cl), system(system) { OpenCLCalcRBTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcRBTorsionForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), system(system) {
} }
~OpenCLCalcRBTorsionForceKernel(); ~OpenCLCalcRBTorsionForceKernel();
/** /**
...@@ -287,6 +293,7 @@ public: ...@@ -287,6 +293,7 @@ public:
double executeEnergy(ContextImpl& context); double executeEnergy(ContextImpl& context);
private: private:
int numTorsions; int numTorsions;
bool hasInitializedKernel;
OpenCLContext& cl; OpenCLContext& cl;
System& system; System& system;
OpenCLArray<mm_float8>* params; OpenCLArray<mm_float8>* params;
...@@ -299,8 +306,8 @@ private: ...@@ -299,8 +306,8 @@ private:
*/ */
class OpenCLCalcNonbondedForceKernel : public CalcNonbondedForceKernel { class OpenCLCalcNonbondedForceKernel : public CalcNonbondedForceKernel {
public: public:
OpenCLCalcNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcNonbondedForceKernel(name, platform), cl(cl), OpenCLCalcNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcNonbondedForceKernel(name, platform),
sigmaEpsilon(NULL), exceptionParams(NULL), exceptionIndices(NULL), cosSinSums(NULL) { hasInitializedKernel(false), cl(cl), sigmaEpsilon(NULL), exceptionParams(NULL), exceptionIndices(NULL), cosSinSums(NULL) {
} }
~OpenCLCalcNonbondedForceKernel(); ~OpenCLCalcNonbondedForceKernel();
/** /**
...@@ -325,6 +332,7 @@ public: ...@@ -325,6 +332,7 @@ public:
double executeEnergy(ContextImpl& context); double executeEnergy(ContextImpl& context);
private: private:
OpenCLContext& cl; OpenCLContext& cl;
bool hasInitializedKernel;
OpenCLArray<mm_float2>* sigmaEpsilon; OpenCLArray<mm_float2>* sigmaEpsilon;
OpenCLArray<mm_float4>* exceptionParams; OpenCLArray<mm_float4>* exceptionParams;
OpenCLArray<mm_int4>* exceptionIndices; OpenCLArray<mm_int4>* exceptionIndices;
...@@ -341,7 +349,7 @@ private: ...@@ -341,7 +349,7 @@ private:
class OpenCLCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel { class OpenCLCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel {
public: public:
OpenCLCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcCustomNonbondedForceKernel(name, platform), OpenCLCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcCustomNonbondedForceKernel(name, platform),
hasCreatedKernels(false), cl(cl), params(NULL), globals(NULL), exceptionParams(NULL), exceptionIndices(NULL), system(system) { hasInitializedKernel(false), cl(cl), params(NULL), globals(NULL), exceptionParams(NULL), exceptionIndices(NULL), tabulatedFunctionParams(NULL), system(system) {
} }
~OpenCLCalcCustomNonbondedForceKernel(); ~OpenCLCalcCustomNonbondedForceKernel();
/** /**
...@@ -365,15 +373,17 @@ public: ...@@ -365,15 +373,17 @@ public:
*/ */
double executeEnergy(ContextImpl& context); double executeEnergy(ContextImpl& context);
private: private:
bool hasCreatedKernels; bool hasInitializedKernel;
OpenCLContext& cl; OpenCLContext& cl;
OpenCLArray<mm_float4>* params; OpenCLArray<mm_float4>* params;
OpenCLArray<cl_float>* globals; OpenCLArray<cl_float>* globals;
OpenCLArray<mm_float4>* exceptionParams; OpenCLArray<mm_float4>* exceptionParams;
OpenCLArray<mm_int4>* exceptionIndices; OpenCLArray<mm_int4>* exceptionIndices;
OpenCLArray<mm_float4>* tabulatedFunctionParams;
cl::Kernel exceptionsKernel; cl::Kernel exceptionsKernel;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<cl_float> globalParamValues; std::vector<cl_float> globalParamValues;
std::vector<OpenCLArray<mm_float4>*> tabulatedFunctions;
System& system; System& system;
}; };
......
...@@ -2,13 +2,9 @@ ...@@ -2,13 +2,9 @@
* Compute custom nonbonded exceptions. * Compute custom nonbonded exceptions.
*/ */
__kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions, float cutoffSquared, float4 periodicBoxSize, __global float4* forceBuffers, __global float* energyBuffer, __kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions, __global float4* forceBuffers, __global float* energyBuffer,
__global float4* posq, __global float4* params, __global int4* indices __global float4* posq, __global float4* params, __global int4* indices
#ifdef HAS_GLOBALS EXTRA_ARGUMENTS) {
, __constant float* globals) {
#else
) {
#endif
int index = get_global_id(0); int index = get_global_id(0);
float energy = 0.0f; float energy = 0.0f;
while (index < numExceptions) { while (index < numExceptions) {
...@@ -18,15 +14,15 @@ __kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions, ...@@ -18,15 +14,15 @@ __kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions,
float4 exceptionParams = params[index]; float4 exceptionParams = params[index];
float4 delta = posq[atoms.y]-posq[atoms.x]; float4 delta = posq[atoms.y]-posq[atoms.x];
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
delta.x -= floor(delta.x/periodicBoxSize.x+0.5f)*periodicBoxSize.x; delta.x -= floor(delta.x/PERIODIC_BOX_SIZE_X+0.5f)*PERIODIC_BOX_SIZE_X;
delta.y -= floor(delta.y/periodicBoxSize.y+0.5f)*periodicBoxSize.y; delta.y -= floor(delta.y/PERIODIC_BOX_SIZE_Y+0.5f)*PERIODIC_BOX_SIZE_Y;
delta.z -= floor(delta.z/periodicBoxSize.z+0.5f)*periodicBoxSize.z; delta.z -= floor(delta.z/PERIODIC_BOX_SIZE_Z+0.5f)*PERIODIC_BOX_SIZE_Z;
#endif #endif
// Compute the force. // Compute the force.
float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z; float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
if (r2 > cutoffSquared) { if (r2 > CUTOFF_SQUARED) {
#else #else
{ {
#endif #endif
......
...@@ -241,8 +241,8 @@ int main() { ...@@ -241,8 +241,8 @@ int main() {
testExceptions(); testExceptions();
testCutoff(); testCutoff();
testPeriodic(); testPeriodic();
// testTabulatedFunction(true); testTabulatedFunction(true);
// testTabulatedFunction(false); testTabulatedFunction(false);
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment