"vscode:/vscode.git/clone" did not exist on "b9b7ce348bc797e040fa7f7fd7932f91f02b5f8e"
Commit 72d59cbe authored by Peter Eastman's avatar Peter Eastman
Browse files

Finished OpenCL implementation of CustomNonbondedForce. Also implemented a few optimizations.

parent 2127b8dd
...@@ -84,6 +84,8 @@ public: ...@@ -84,6 +84,8 @@ public:
ExpressionTreeNode(const ExpressionTreeNode& node); ExpressionTreeNode(const ExpressionTreeNode& node);
ExpressionTreeNode(); ExpressionTreeNode();
~ExpressionTreeNode(); ~ExpressionTreeNode();
bool operator==(const ExpressionTreeNode& node) const;
bool operator!=(const ExpressionTreeNode& node) const;
ExpressionTreeNode& operator=(const ExpressionTreeNode& node); ExpressionTreeNode& operator=(const ExpressionTreeNode& node);
/** /**
* Get the Operation performed by this node. * Get the Operation performed by this node.
......
...@@ -95,6 +95,12 @@ public: ...@@ -95,6 +95,12 @@ public:
* @param variable the variable with respect to which the derivate should be taken * @param variable the variable with respect to which the derivate should be taken
*/ */
virtual ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const = 0; virtual ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const = 0;
virtual bool operator!=(const Operation& op) const {
return op.getId() != getId();
}
virtual bool operator==(const Operation& op) const {
return !(*this != op);
}
class Constant; class Constant;
class Variable; class Variable;
class Custom; class Custom;
...@@ -149,6 +155,10 @@ public: ...@@ -149,6 +155,10 @@ public:
double getValue() const { double getValue() const {
return value; return value;
} }
bool operator!=(const Operation& op) const {
const Constant* o = dynamic_cast<const Constant*>(&op);
return (o == NULL || o->value != value);
}
private: private:
double value; double value;
}; };
...@@ -176,6 +186,10 @@ public: ...@@ -176,6 +186,10 @@ public:
return iter->second; return iter->second;
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const; ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
bool operator!=(const Operation& op) const {
const Variable* o = dynamic_cast<const Variable*>(&op);
return (o == NULL || o->name != name);
}
private: private:
std::string name; std::string name;
}; };
...@@ -214,6 +228,10 @@ public: ...@@ -214,6 +228,10 @@ public:
const std::vector<int>& getDerivOrder() const { const std::vector<int>& getDerivOrder() const {
return derivOrder; return derivOrder;
} }
bool operator!=(const Operation& op) const {
const Custom* o = dynamic_cast<const Custom*>(&op);
return (o == NULL || o->name != name || o->isDerivative != isDerivative || o->derivOrder != derivOrder);
}
private: private:
std::string name; std::string name;
CustomFunction* function; CustomFunction* function;
...@@ -708,6 +726,10 @@ public: ...@@ -708,6 +726,10 @@ public:
double getValue() const { double getValue() const {
return value; return value;
} }
bool operator!=(const Operation& op) const {
const AddConstant* o = dynamic_cast<const AddConstant*>(&op);
return (o == NULL || o->value != value);
}
private: private:
double value; double value;
}; };
...@@ -737,6 +759,10 @@ public: ...@@ -737,6 +759,10 @@ public:
double getValue() const { double getValue() const {
return value; return value;
} }
bool operator!=(const Operation& op) const {
const MultiplyConstant* o = dynamic_cast<const MultiplyConstant*>(&op);
return (o == NULL || o->value != value);
}
private: private:
double value; double value;
}; };
...@@ -766,6 +792,10 @@ public: ...@@ -766,6 +792,10 @@ public:
double getValue() const { double getValue() const {
return value; return value;
} }
bool operator!=(const Operation& op) const {
const PowerConstant* o = dynamic_cast<const PowerConstant*>(&op);
return (o == NULL || o->value != value);
}
private: private:
double value; double value;
}; };
......
...@@ -48,6 +48,11 @@ class ExpressionProgram; ...@@ -48,6 +48,11 @@ class ExpressionProgram;
class LEPTON_EXPORT ParsedExpression { class LEPTON_EXPORT ParsedExpression {
public: public:
/**
* Create an uninitialized ParsedExpression. This exists so that ParsedExpressions can be put in STL containers.
* Doing anything with it will produce an exception.
*/
ParsedExpression();
/** /**
* Create a ParsedExpression. Normally you will not call this directly. Instead, use the Parser class * Create a ParsedExpression. Normally you will not call this directly. Instead, use the Parser class
* to parse expression. * to parse expression.
......
...@@ -70,6 +70,19 @@ ExpressionTreeNode::~ExpressionTreeNode() { ...@@ -70,6 +70,19 @@ ExpressionTreeNode::~ExpressionTreeNode() {
delete operation; delete operation;
} }
bool ExpressionTreeNode::operator!=(const ExpressionTreeNode& node) const {
if (node.getOperation() != getOperation())
return true;
for (int i = 0; i < (int) getChildren().size(); i++)
if (getChildren()[i] != node.getChildren()[i])
return true;
return false;
}
bool ExpressionTreeNode::operator==(const ExpressionTreeNode& node) const {
return !(*this != node);
}
ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node) { ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node) {
if (operation != NULL) if (operation != NULL)
delete operation; delete operation;
......
...@@ -38,10 +38,15 @@ ...@@ -38,10 +38,15 @@
using namespace Lepton; using namespace Lepton;
using namespace std; using namespace std;
ParsedExpression::ParsedExpression() : rootNode(ExpressionTreeNode()) {
}
ParsedExpression::ParsedExpression(const ExpressionTreeNode& rootNode) : rootNode(rootNode) { ParsedExpression::ParsedExpression(const ExpressionTreeNode& rootNode) : rootNode(rootNode) {
} }
const ExpressionTreeNode& ParsedExpression::getRootNode() const { const ExpressionTreeNode& ParsedExpression::getRootNode() const {
if (&rootNode.getOperation() == NULL)
throw Exception("Illegal call to an initialized ParsedExpression");
return rootNode; return rootNode;
} }
......
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include "OpenCLExpressionUtilities.h" #include "OpenCLExpressionUtilities.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "lepton/Operation.h" #include "lepton/Operation.h"
#include <sstream>
using namespace OpenMM; using namespace OpenMM;
using namespace Lepton; using namespace Lepton;
...@@ -46,69 +45,152 @@ static string intToString(int value) { ...@@ -46,69 +45,152 @@ static string intToString(int value) {
return s.str(); return s.str();
} }
string OpenCLExpressionUtilities::createExpression(const ParsedExpression& expression, const map<string, string>& variables) { string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
return processExpression(expression.getRootNode(), variables); const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams) {
stringstream out;
vector<pair<ExpressionTreeNode, string> > temps;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, variables, functions, prefix, functionParams);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
}
return out.str();
} }
string OpenCLExpressionUtilities::processExpression(const ExpressionTreeNode& node, const map<string, string>& variables) { void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const map<string, string>& variables, const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams) {
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
return;
for (int i = 0; i < (int) node.getChildren().size(); i++)
processExpression(out, node.getChildren()[i], temps, variables, functions, prefix, functionParams);
string name = prefix+intToString(temps.size());
out << "float " << name << " = ";
switch (node.getOperation().getId()) { switch (node.getOperation().getId()) {
case Operation::CONSTANT: case Operation::CONSTANT:
return doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue()); out << doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue());
break;
case Operation::VARIABLE: case Operation::VARIABLE:
{ {
map<string, string>::const_iterator iter = variables.find(node.getOperation().getName()); map<string, string>::const_iterator iter = variables.find(node.getOperation().getName());
if (iter == variables.end()) if (iter == variables.end())
throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName()); throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
return iter->second; out << iter->second;
break;
}
case Operation::CUSTOM:
{
int i;
for (i = 0; i < (int) functions.size() && functions[i].first != node.getOperation().getName(); i++)
;
if (i == functions.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
out << "0.0f;\n";
out << "{\n";
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "float x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= params.x && x <= params.y) {\n";
out << "int index = (int) (floor((x-params.x)*params.z));\n";
out << "float4 coeff = " << functions[i].second << "[index];\n";
out << "x = (x-params.x)*params.z-index;\n";
if (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 0)
out << name << " = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));\n";
else
out << name << " = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;\n";
out << "}\n";
out << "}";
break;
} }
case Operation::ADD: case Operation::ADD:
return "("+processExpression(node.getChildren()[0], variables)+")+("+processExpression(node.getChildren()[1], variables)+")"; out << getTempName(node.getChildren()[0], temps) << "+" << getTempName(node.getChildren()[1], temps);
break;
case Operation::SUBTRACT: case Operation::SUBTRACT:
return "("+processExpression(node.getChildren()[0], variables)+")-("+processExpression(node.getChildren()[1], variables)+")"; out << getTempName(node.getChildren()[0], temps) << "-" << getTempName(node.getChildren()[1], temps);
break;
case Operation::MULTIPLY: case Operation::MULTIPLY:
return "("+processExpression(node.getChildren()[0], variables)+")*("+processExpression(node.getChildren()[1], variables)+")"; out << getTempName(node.getChildren()[0], temps) << "*" << getTempName(node.getChildren()[1], temps);
break;
case Operation::DIVIDE: case Operation::DIVIDE:
return "("+processExpression(node.getChildren()[0], variables)+")/("+processExpression(node.getChildren()[1], variables)+")"; out << getTempName(node.getChildren()[0], temps) << "/" << getTempName(node.getChildren()[1], temps);
break;
case Operation::POWER: case Operation::POWER:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), ("+processExpression(node.getChildren()[1], variables)+"))"; out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << getTempName(node.getChildren()[1], temps) << ")";
break;
case Operation::NEGATE: case Operation::NEGATE:
return "-("+processExpression(node.getChildren()[0], variables)+")"; out << "-" << getTempName(node.getChildren()[0], temps);
break;
case Operation::SQRT: case Operation::SQRT:
return "sqrt("+processExpression(node.getChildren()[0], variables)+")"; out << "sqrt(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::EXP: case Operation::EXP:
return "exp("+processExpression(node.getChildren()[0], variables)+")"; out << "exp(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::LOG: case Operation::LOG:
return "log("+processExpression(node.getChildren()[0], variables)+")"; out << "log(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SIN: case Operation::SIN:
return "sin("+processExpression(node.getChildren()[0], variables)+")"; out << "sin(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::COS: case Operation::COS:
return "cos("+processExpression(node.getChildren()[0], variables)+")"; out << "cos(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SEC: case Operation::SEC:
return "1.0f/cos("+processExpression(node.getChildren()[0], variables)+")"; out << "1.0f/cos(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::CSC: case Operation::CSC:
return "1.0f/sin("+processExpression(node.getChildren()[0], variables)+")"; out << "1.0f/sin(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::TAN: case Operation::TAN:
return "tan("+processExpression(node.getChildren()[0], variables)+")"; out << "tan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::COT: case Operation::COT:
return "1.0f/tan("+processExpression(node.getChildren()[0], variables)+")"; out << "1.0f/tan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ASIN: case Operation::ASIN:
return "asin("+processExpression(node.getChildren()[0], variables)+")"; out << "asin(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ACOS: case Operation::ACOS:
return "acos("+processExpression(node.getChildren()[0], variables)+")"; out << "acos(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ATAN: case Operation::ATAN:
return "atan("+processExpression(node.getChildren()[0], variables)+")"; out << "atan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SQUARE: case Operation::SQUARE:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), 2.0f)"; {
string arg = getTempName(node.getChildren()[0], temps);
out << arg << "*" << arg;
break;
}
case Operation::CUBE: case Operation::CUBE:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), 3.0f)"; {
string arg = getTempName(node.getChildren()[0], temps);
out << arg << "*" << arg << "*" << arg;
break;
}
case Operation::RECIPROCAL: case Operation::RECIPROCAL:
return "1.0f/("+processExpression(node.getChildren()[0], variables)+")"; out << "1.0f/" << getTempName(node.getChildren()[0], temps);
break;
case Operation::ADD_CONSTANT: case Operation::ADD_CONSTANT:
return doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue())+"+("+processExpression(node.getChildren()[0], variables)+")"; out << doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue()) << "+" << getTempName(node.getChildren()[0], temps);
break;
case Operation::MULTIPLY_CONSTANT: case Operation::MULTIPLY_CONSTANT:
return doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue())+"*("+processExpression(node.getChildren()[0], variables)+")"; out << doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()) << "*" << getTempName(node.getChildren()[0], temps);
break;
case Operation::POWER_CONSTANT: case Operation::POWER_CONSTANT:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), "+doubleToString(dynamic_cast<const Operation::PowerConstant*>(&node.getOperation())->getValue())+")"; out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << doubleToString(dynamic_cast<const Operation::PowerConstant*>(&node.getOperation())->getValue()) << ")";
break;
default:
throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
} }
throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName()); out << ";\n";
temps.push_back(make_pair(node, name));
}
string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, const vector<pair<ExpressionTreeNode, string> >& temps) {
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
return temps[i].second;
stringstream out;
out << "Internal error: No temporary variable for expression node: " << node;
throw OpenMMException(out.str());
} }
...@@ -30,7 +30,9 @@ ...@@ -30,7 +30,9 @@
#include "lepton/ExpressionTreeNode.h" #include "lepton/ExpressionTreeNode.h"
#include "lepton/ParsedExpression.h" #include "lepton/ParsedExpression.h"
#include <map> #include <map>
#include <sstream>
#include <string> #include <string>
#include <utility>
namespace OpenMM { namespace OpenMM {
...@@ -41,9 +43,13 @@ namespace OpenMM { ...@@ -41,9 +43,13 @@ namespace OpenMM {
class OpenCLExpressionUtilities { class OpenCLExpressionUtilities {
public: public:
static std::string createExpression(const Lepton::ParsedExpression& expression, const std::map<std::string, std::string>& variables); static std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams);
private: private:
static std::string processExpression(const Lepton::ExpressionTreeNode& node, const std::map<std::string, std::string>& variables); static void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams);
static std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
}; };
} // namespace OpenMM } // namespace OpenMM
......
This diff is collapsed.
...@@ -150,8 +150,8 @@ private: ...@@ -150,8 +150,8 @@ private:
*/ */
class OpenCLCalcHarmonicBondForceKernel : public CalcHarmonicBondForceKernel { class OpenCLCalcHarmonicBondForceKernel : public CalcHarmonicBondForceKernel {
public: public:
OpenCLCalcHarmonicBondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : OpenCLCalcHarmonicBondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcHarmonicBondForceKernel(name, platform),
CalcHarmonicBondForceKernel(name, platform), cl(cl), system(system), params(NULL), indices(NULL) { hasInitializedKernel(false), cl(cl), system(system), params(NULL), indices(NULL) {
} }
~OpenCLCalcHarmonicBondForceKernel(); ~OpenCLCalcHarmonicBondForceKernel();
/** /**
...@@ -176,6 +176,7 @@ public: ...@@ -176,6 +176,7 @@ public:
double executeEnergy(ContextImpl& context); double executeEnergy(ContextImpl& context);
private: private:
int numBonds; int numBonds;
bool hasInitializedKernel;
OpenCLContext& cl; OpenCLContext& cl;
System& system; System& system;
OpenCLArray<mm_float2>* params; OpenCLArray<mm_float2>* params;
...@@ -188,7 +189,8 @@ private: ...@@ -188,7 +189,8 @@ private:
*/ */
class OpenCLCalcHarmonicAngleForceKernel : public CalcHarmonicAngleForceKernel { class OpenCLCalcHarmonicAngleForceKernel : public CalcHarmonicAngleForceKernel {
public: public:
OpenCLCalcHarmonicAngleForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcHarmonicAngleForceKernel(name, platform), cl(cl), system(system) { OpenCLCalcHarmonicAngleForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcHarmonicAngleForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), system(system) {
} }
~OpenCLCalcHarmonicAngleForceKernel(); ~OpenCLCalcHarmonicAngleForceKernel();
/** /**
...@@ -213,6 +215,7 @@ public: ...@@ -213,6 +215,7 @@ public:
double executeEnergy(ContextImpl& context); double executeEnergy(ContextImpl& context);
private: private:
int numAngles; int numAngles;
bool hasInitializedKernel;
OpenCLContext& cl; OpenCLContext& cl;
System& system; System& system;
OpenCLArray<mm_float2>* params; OpenCLArray<mm_float2>* params;
...@@ -225,7 +228,8 @@ private: ...@@ -225,7 +228,8 @@ private:
*/ */
class OpenCLCalcPeriodicTorsionForceKernel : public CalcPeriodicTorsionForceKernel { class OpenCLCalcPeriodicTorsionForceKernel : public CalcPeriodicTorsionForceKernel {
public: public:
OpenCLCalcPeriodicTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcPeriodicTorsionForceKernel(name, platform), cl(cl), system(system) { OpenCLCalcPeriodicTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcPeriodicTorsionForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), system(system) {
} }
~OpenCLCalcPeriodicTorsionForceKernel(); ~OpenCLCalcPeriodicTorsionForceKernel();
/** /**
...@@ -250,6 +254,7 @@ public: ...@@ -250,6 +254,7 @@ public:
double executeEnergy(ContextImpl& context); double executeEnergy(ContextImpl& context);
private: private:
int numTorsions; int numTorsions;
bool hasInitializedKernel;
OpenCLContext& cl; OpenCLContext& cl;
System& system; System& system;
OpenCLArray<mm_float4>* params; OpenCLArray<mm_float4>* params;
...@@ -262,7 +267,8 @@ private: ...@@ -262,7 +267,8 @@ private:
*/ */
class OpenCLCalcRBTorsionForceKernel : public CalcRBTorsionForceKernel { class OpenCLCalcRBTorsionForceKernel : public CalcRBTorsionForceKernel {
public: public:
OpenCLCalcRBTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcRBTorsionForceKernel(name, platform), cl(cl), system(system) { OpenCLCalcRBTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcRBTorsionForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), system(system) {
} }
~OpenCLCalcRBTorsionForceKernel(); ~OpenCLCalcRBTorsionForceKernel();
/** /**
...@@ -287,6 +293,7 @@ public: ...@@ -287,6 +293,7 @@ public:
double executeEnergy(ContextImpl& context); double executeEnergy(ContextImpl& context);
private: private:
int numTorsions; int numTorsions;
bool hasInitializedKernel;
OpenCLContext& cl; OpenCLContext& cl;
System& system; System& system;
OpenCLArray<mm_float8>* params; OpenCLArray<mm_float8>* params;
...@@ -299,8 +306,8 @@ private: ...@@ -299,8 +306,8 @@ private:
*/ */
class OpenCLCalcNonbondedForceKernel : public CalcNonbondedForceKernel { class OpenCLCalcNonbondedForceKernel : public CalcNonbondedForceKernel {
public: public:
OpenCLCalcNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcNonbondedForceKernel(name, platform), cl(cl), OpenCLCalcNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcNonbondedForceKernel(name, platform),
sigmaEpsilon(NULL), exceptionParams(NULL), exceptionIndices(NULL), cosSinSums(NULL) { hasInitializedKernel(false), cl(cl), sigmaEpsilon(NULL), exceptionParams(NULL), exceptionIndices(NULL), cosSinSums(NULL) {
} }
~OpenCLCalcNonbondedForceKernel(); ~OpenCLCalcNonbondedForceKernel();
/** /**
...@@ -325,6 +332,7 @@ public: ...@@ -325,6 +332,7 @@ public:
double executeEnergy(ContextImpl& context); double executeEnergy(ContextImpl& context);
private: private:
OpenCLContext& cl; OpenCLContext& cl;
bool hasInitializedKernel;
OpenCLArray<mm_float2>* sigmaEpsilon; OpenCLArray<mm_float2>* sigmaEpsilon;
OpenCLArray<mm_float4>* exceptionParams; OpenCLArray<mm_float4>* exceptionParams;
OpenCLArray<mm_int4>* exceptionIndices; OpenCLArray<mm_int4>* exceptionIndices;
...@@ -341,7 +349,7 @@ private: ...@@ -341,7 +349,7 @@ private:
class OpenCLCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel { class OpenCLCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel {
public: public:
OpenCLCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcCustomNonbondedForceKernel(name, platform), OpenCLCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcCustomNonbondedForceKernel(name, platform),
hasCreatedKernels(false), cl(cl), params(NULL), globals(NULL), exceptionParams(NULL), exceptionIndices(NULL), system(system) { hasInitializedKernel(false), cl(cl), params(NULL), globals(NULL), exceptionParams(NULL), exceptionIndices(NULL), tabulatedFunctionParams(NULL), system(system) {
} }
~OpenCLCalcCustomNonbondedForceKernel(); ~OpenCLCalcCustomNonbondedForceKernel();
/** /**
...@@ -365,15 +373,17 @@ public: ...@@ -365,15 +373,17 @@ public:
*/ */
double executeEnergy(ContextImpl& context); double executeEnergy(ContextImpl& context);
private: private:
bool hasCreatedKernels; bool hasInitializedKernel;
OpenCLContext& cl; OpenCLContext& cl;
OpenCLArray<mm_float4>* params; OpenCLArray<mm_float4>* params;
OpenCLArray<cl_float>* globals; OpenCLArray<cl_float>* globals;
OpenCLArray<mm_float4>* exceptionParams; OpenCLArray<mm_float4>* exceptionParams;
OpenCLArray<mm_int4>* exceptionIndices; OpenCLArray<mm_int4>* exceptionIndices;
OpenCLArray<mm_float4>* tabulatedFunctionParams;
cl::Kernel exceptionsKernel; cl::Kernel exceptionsKernel;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<cl_float> globalParamValues; std::vector<cl_float> globalParamValues;
std::vector<OpenCLArray<mm_float4>*> tabulatedFunctions;
System& system; System& system;
}; };
......
...@@ -2,13 +2,9 @@ ...@@ -2,13 +2,9 @@
* Compute custom nonbonded exceptions. * Compute custom nonbonded exceptions.
*/ */
__kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions, float cutoffSquared, float4 periodicBoxSize, __global float4* forceBuffers, __global float* energyBuffer, __kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions, __global float4* forceBuffers, __global float* energyBuffer,
__global float4* posq, __global float4* params, __global int4* indices __global float4* posq, __global float4* params, __global int4* indices
#ifdef HAS_GLOBALS EXTRA_ARGUMENTS) {
, __constant float* globals) {
#else
) {
#endif
int index = get_global_id(0); int index = get_global_id(0);
float energy = 0.0f; float energy = 0.0f;
while (index < numExceptions) { while (index < numExceptions) {
...@@ -18,15 +14,15 @@ __kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions, ...@@ -18,15 +14,15 @@ __kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions,
float4 exceptionParams = params[index]; float4 exceptionParams = params[index];
float4 delta = posq[atoms.y]-posq[atoms.x]; float4 delta = posq[atoms.y]-posq[atoms.x];
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
delta.x -= floor(delta.x/periodicBoxSize.x+0.5f)*periodicBoxSize.x; delta.x -= floor(delta.x/PERIODIC_BOX_SIZE_X+0.5f)*PERIODIC_BOX_SIZE_X;
delta.y -= floor(delta.y/periodicBoxSize.y+0.5f)*periodicBoxSize.y; delta.y -= floor(delta.y/PERIODIC_BOX_SIZE_Y+0.5f)*PERIODIC_BOX_SIZE_Y;
delta.z -= floor(delta.z/periodicBoxSize.z+0.5f)*periodicBoxSize.z; delta.z -= floor(delta.z/PERIODIC_BOX_SIZE_Z+0.5f)*PERIODIC_BOX_SIZE_Z;
#endif #endif
// Compute the force. // Compute the force.
float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z; float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
if (r2 > cutoffSquared) { if (r2 > CUTOFF_SQUARED) {
#else #else
{ {
#endif #endif
......
...@@ -241,8 +241,8 @@ int main() { ...@@ -241,8 +241,8 @@ int main() {
testExceptions(); testExceptions();
testCutoff(); testCutoff();
testPeriodic(); testPeriodic();
// testTabulatedFunction(true); testTabulatedFunction(true);
// testTabulatedFunction(false); testTabulatedFunction(false);
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment