Commit 72d59cbe authored by Peter Eastman's avatar Peter Eastman
Browse files

Finished OpenCL implementation of CustomNonbondedForce. Also implemented a few optimizations.

parent 2127b8dd
......@@ -84,6 +84,8 @@ public:
ExpressionTreeNode(const ExpressionTreeNode& node);
ExpressionTreeNode();
~ExpressionTreeNode();
bool operator==(const ExpressionTreeNode& node) const;
bool operator!=(const ExpressionTreeNode& node) const;
ExpressionTreeNode& operator=(const ExpressionTreeNode& node);
/**
* Get the Operation performed by this node.
......
......@@ -95,6 +95,12 @@ public:
* @param variable the variable with respect to which the derivate should be taken
*/
virtual ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const = 0;
virtual bool operator!=(const Operation& op) const {
return op.getId() != getId();
}
virtual bool operator==(const Operation& op) const {
return !(*this != op);
}
class Constant;
class Variable;
class Custom;
......@@ -149,6 +155,10 @@ public:
double getValue() const {
return value;
}
bool operator!=(const Operation& op) const {
const Constant* o = dynamic_cast<const Constant*>(&op);
return (o == NULL || o->value != value);
}
private:
double value;
};
......@@ -176,6 +186,10 @@ public:
return iter->second;
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
bool operator!=(const Operation& op) const {
const Variable* o = dynamic_cast<const Variable*>(&op);
return (o == NULL || o->name != name);
}
private:
std::string name;
};
......@@ -214,6 +228,10 @@ public:
const std::vector<int>& getDerivOrder() const {
return derivOrder;
}
bool operator!=(const Operation& op) const {
const Custom* o = dynamic_cast<const Custom*>(&op);
return (o == NULL || o->name != name || o->isDerivative != isDerivative || o->derivOrder != derivOrder);
}
private:
std::string name;
CustomFunction* function;
......@@ -708,6 +726,10 @@ public:
double getValue() const {
return value;
}
bool operator!=(const Operation& op) const {
const AddConstant* o = dynamic_cast<const AddConstant*>(&op);
return (o == NULL || o->value != value);
}
private:
double value;
};
......@@ -737,6 +759,10 @@ public:
double getValue() const {
return value;
}
bool operator!=(const Operation& op) const {
const MultiplyConstant* o = dynamic_cast<const MultiplyConstant*>(&op);
return (o == NULL || o->value != value);
}
private:
double value;
};
......@@ -766,6 +792,10 @@ public:
double getValue() const {
return value;
}
bool operator!=(const Operation& op) const {
const PowerConstant* o = dynamic_cast<const PowerConstant*>(&op);
return (o == NULL || o->value != value);
}
private:
double value;
};
......
......@@ -48,6 +48,11 @@ class ExpressionProgram;
class LEPTON_EXPORT ParsedExpression {
public:
/**
* Create an uninitialized ParsedExpression. This exists so that ParsedExpressions can be put in STL containers.
* Doing anything with it will produce an exception.
*/
ParsedExpression();
/**
* Create a ParsedExpression. Normally you will not call this directly. Instead, use the Parser class
* to parse expression.
......
......@@ -70,6 +70,19 @@ ExpressionTreeNode::~ExpressionTreeNode() {
delete operation;
}
bool ExpressionTreeNode::operator!=(const ExpressionTreeNode& node) const {
if (node.getOperation() != getOperation())
return true;
for (int i = 0; i < (int) getChildren().size(); i++)
if (getChildren()[i] != node.getChildren()[i])
return true;
return false;
}
bool ExpressionTreeNode::operator==(const ExpressionTreeNode& node) const {
return !(*this != node);
}
ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node) {
if (operation != NULL)
delete operation;
......
......@@ -38,10 +38,15 @@
using namespace Lepton;
using namespace std;
ParsedExpression::ParsedExpression() : rootNode(ExpressionTreeNode()) {
}
ParsedExpression::ParsedExpression(const ExpressionTreeNode& rootNode) : rootNode(rootNode) {
}
const ExpressionTreeNode& ParsedExpression::getRootNode() const {
if (&rootNode.getOperation() == NULL)
throw Exception("Illegal call to an initialized ParsedExpression");
return rootNode;
}
......
......@@ -27,7 +27,6 @@
#include "OpenCLExpressionUtilities.h"
#include "openmm/OpenMMException.h"
#include "lepton/Operation.h"
#include <sstream>
using namespace OpenMM;
using namespace Lepton;
......@@ -46,69 +45,152 @@ static string intToString(int value) {
return s.str();
}
string OpenCLExpressionUtilities::createExpression(const ParsedExpression& expression, const map<string, string>& variables) {
return processExpression(expression.getRootNode(), variables);
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams) {
stringstream out;
vector<pair<ExpressionTreeNode, string> > temps;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, variables, functions, prefix, functionParams);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
}
return out.str();
}
string OpenCLExpressionUtilities::processExpression(const ExpressionTreeNode& node, const map<string, string>& variables) {
void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const map<string, string>& variables, const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams) {
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
return;
for (int i = 0; i < (int) node.getChildren().size(); i++)
processExpression(out, node.getChildren()[i], temps, variables, functions, prefix, functionParams);
string name = prefix+intToString(temps.size());
out << "float " << name << " = ";
switch (node.getOperation().getId()) {
case Operation::CONSTANT:
return doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue());
out << doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue());
break;
case Operation::VARIABLE:
{
map<string, string>::const_iterator iter = variables.find(node.getOperation().getName());
if (iter == variables.end())
throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
return iter->second;
out << iter->second;
break;
}
case Operation::CUSTOM:
{
int i;
for (i = 0; i < (int) functions.size() && functions[i].first != node.getOperation().getName(); i++)
;
if (i == functions.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
out << "0.0f;\n";
out << "{\n";
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "float x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= params.x && x <= params.y) {\n";
out << "int index = (int) (floor((x-params.x)*params.z));\n";
out << "float4 coeff = " << functions[i].second << "[index];\n";
out << "x = (x-params.x)*params.z-index;\n";
if (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 0)
out << name << " = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));\n";
else
out << name << " = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;\n";
out << "}\n";
out << "}";
break;
}
case Operation::ADD:
return "("+processExpression(node.getChildren()[0], variables)+")+("+processExpression(node.getChildren()[1], variables)+")";
out << getTempName(node.getChildren()[0], temps) << "+" << getTempName(node.getChildren()[1], temps);
break;
case Operation::SUBTRACT:
return "("+processExpression(node.getChildren()[0], variables)+")-("+processExpression(node.getChildren()[1], variables)+")";
out << getTempName(node.getChildren()[0], temps) << "-" << getTempName(node.getChildren()[1], temps);
break;
case Operation::MULTIPLY:
return "("+processExpression(node.getChildren()[0], variables)+")*("+processExpression(node.getChildren()[1], variables)+")";
out << getTempName(node.getChildren()[0], temps) << "*" << getTempName(node.getChildren()[1], temps);
break;
case Operation::DIVIDE:
return "("+processExpression(node.getChildren()[0], variables)+")/("+processExpression(node.getChildren()[1], variables)+")";
out << getTempName(node.getChildren()[0], temps) << "/" << getTempName(node.getChildren()[1], temps);
break;
case Operation::POWER:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), ("+processExpression(node.getChildren()[1], variables)+"))";
out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << getTempName(node.getChildren()[1], temps) << ")";
break;
case Operation::NEGATE:
return "-("+processExpression(node.getChildren()[0], variables)+")";
out << "-" << getTempName(node.getChildren()[0], temps);
break;
case Operation::SQRT:
return "sqrt("+processExpression(node.getChildren()[0], variables)+")";
out << "sqrt(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::EXP:
return "exp("+processExpression(node.getChildren()[0], variables)+")";
out << "exp(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::LOG:
return "log("+processExpression(node.getChildren()[0], variables)+")";
out << "log(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SIN:
return "sin("+processExpression(node.getChildren()[0], variables)+")";
out << "sin(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::COS:
return "cos("+processExpression(node.getChildren()[0], variables)+")";
out << "cos(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SEC:
return "1.0f/cos("+processExpression(node.getChildren()[0], variables)+")";
out << "1.0f/cos(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::CSC:
return "1.0f/sin("+processExpression(node.getChildren()[0], variables)+")";
out << "1.0f/sin(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::TAN:
return "tan("+processExpression(node.getChildren()[0], variables)+")";
out << "tan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::COT:
return "1.0f/tan("+processExpression(node.getChildren()[0], variables)+")";
out << "1.0f/tan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ASIN:
return "asin("+processExpression(node.getChildren()[0], variables)+")";
out << "asin(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ACOS:
return "acos("+processExpression(node.getChildren()[0], variables)+")";
out << "acos(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ATAN:
return "atan("+processExpression(node.getChildren()[0], variables)+")";
out << "atan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SQUARE:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), 2.0f)";
{
string arg = getTempName(node.getChildren()[0], temps);
out << arg << "*" << arg;
break;
}
case Operation::CUBE:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), 3.0f)";
{
string arg = getTempName(node.getChildren()[0], temps);
out << arg << "*" << arg << "*" << arg;
break;
}
case Operation::RECIPROCAL:
return "1.0f/("+processExpression(node.getChildren()[0], variables)+")";
out << "1.0f/" << getTempName(node.getChildren()[0], temps);
break;
case Operation::ADD_CONSTANT:
return doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue())+"+("+processExpression(node.getChildren()[0], variables)+")";
out << doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue()) << "+" << getTempName(node.getChildren()[0], temps);
break;
case Operation::MULTIPLY_CONSTANT:
return doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue())+"*("+processExpression(node.getChildren()[0], variables)+")";
out << doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()) << "*" << getTempName(node.getChildren()[0], temps);
break;
case Operation::POWER_CONSTANT:
return "pow(("+processExpression(node.getChildren()[0], variables)+"), "+doubleToString(dynamic_cast<const Operation::PowerConstant*>(&node.getOperation())->getValue())+")";
out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << doubleToString(dynamic_cast<const Operation::PowerConstant*>(&node.getOperation())->getValue()) << ")";
break;
default:
throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
}
throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
out << ";\n";
temps.push_back(make_pair(node, name));
}
string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, const vector<pair<ExpressionTreeNode, string> >& temps) {
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
return temps[i].second;
stringstream out;
out << "Internal error: No temporary variable for expression node: " << node;
throw OpenMMException(out.str());
}
......@@ -30,7 +30,9 @@
#include "lepton/ExpressionTreeNode.h"
#include "lepton/ParsedExpression.h"
#include <map>
#include <sstream>
#include <string>
#include <utility>
namespace OpenMM {
......@@ -41,9 +43,13 @@ namespace OpenMM {
class OpenCLExpressionUtilities {
public:
static std::string createExpression(const Lepton::ParsedExpression& expression, const std::map<std::string, std::string>& variables);
static std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams);
private:
static std::string processExpression(const Lepton::ExpressionTreeNode& node, const std::map<std::string, std::string>& variables);
static void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams);
static std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
};
} // namespace OpenMM
......
This diff is collapsed.
......@@ -150,8 +150,8 @@ private:
*/
class OpenCLCalcHarmonicBondForceKernel : public CalcHarmonicBondForceKernel {
public:
OpenCLCalcHarmonicBondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) :
CalcHarmonicBondForceKernel(name, platform), cl(cl), system(system), params(NULL), indices(NULL) {
OpenCLCalcHarmonicBondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcHarmonicBondForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), system(system), params(NULL), indices(NULL) {
}
~OpenCLCalcHarmonicBondForceKernel();
/**
......@@ -176,6 +176,7 @@ public:
double executeEnergy(ContextImpl& context);
private:
int numBonds;
bool hasInitializedKernel;
OpenCLContext& cl;
System& system;
OpenCLArray<mm_float2>* params;
......@@ -188,7 +189,8 @@ private:
*/
class OpenCLCalcHarmonicAngleForceKernel : public CalcHarmonicAngleForceKernel {
public:
OpenCLCalcHarmonicAngleForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcHarmonicAngleForceKernel(name, platform), cl(cl), system(system) {
OpenCLCalcHarmonicAngleForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcHarmonicAngleForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), system(system) {
}
~OpenCLCalcHarmonicAngleForceKernel();
/**
......@@ -213,6 +215,7 @@ public:
double executeEnergy(ContextImpl& context);
private:
int numAngles;
bool hasInitializedKernel;
OpenCLContext& cl;
System& system;
OpenCLArray<mm_float2>* params;
......@@ -225,7 +228,8 @@ private:
*/
class OpenCLCalcPeriodicTorsionForceKernel : public CalcPeriodicTorsionForceKernel {
public:
OpenCLCalcPeriodicTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcPeriodicTorsionForceKernel(name, platform), cl(cl), system(system) {
OpenCLCalcPeriodicTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcPeriodicTorsionForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), system(system) {
}
~OpenCLCalcPeriodicTorsionForceKernel();
/**
......@@ -250,6 +254,7 @@ public:
double executeEnergy(ContextImpl& context);
private:
int numTorsions;
bool hasInitializedKernel;
OpenCLContext& cl;
System& system;
OpenCLArray<mm_float4>* params;
......@@ -262,7 +267,8 @@ private:
*/
class OpenCLCalcRBTorsionForceKernel : public CalcRBTorsionForceKernel {
public:
OpenCLCalcRBTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcRBTorsionForceKernel(name, platform), cl(cl), system(system) {
OpenCLCalcRBTorsionForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcRBTorsionForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), system(system) {
}
~OpenCLCalcRBTorsionForceKernel();
/**
......@@ -287,6 +293,7 @@ public:
double executeEnergy(ContextImpl& context);
private:
int numTorsions;
bool hasInitializedKernel;
OpenCLContext& cl;
System& system;
OpenCLArray<mm_float8>* params;
......@@ -299,8 +306,8 @@ private:
*/
class OpenCLCalcNonbondedForceKernel : public CalcNonbondedForceKernel {
public:
OpenCLCalcNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcNonbondedForceKernel(name, platform), cl(cl),
sigmaEpsilon(NULL), exceptionParams(NULL), exceptionIndices(NULL), cosSinSums(NULL) {
OpenCLCalcNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcNonbondedForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), sigmaEpsilon(NULL), exceptionParams(NULL), exceptionIndices(NULL), cosSinSums(NULL) {
}
~OpenCLCalcNonbondedForceKernel();
/**
......@@ -325,6 +332,7 @@ public:
double executeEnergy(ContextImpl& context);
private:
OpenCLContext& cl;
bool hasInitializedKernel;
OpenCLArray<mm_float2>* sigmaEpsilon;
OpenCLArray<mm_float4>* exceptionParams;
OpenCLArray<mm_int4>* exceptionIndices;
......@@ -341,7 +349,7 @@ private:
class OpenCLCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel {
public:
OpenCLCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcCustomNonbondedForceKernel(name, platform),
hasCreatedKernels(false), cl(cl), params(NULL), globals(NULL), exceptionParams(NULL), exceptionIndices(NULL), system(system) {
hasInitializedKernel(false), cl(cl), params(NULL), globals(NULL), exceptionParams(NULL), exceptionIndices(NULL), tabulatedFunctionParams(NULL), system(system) {
}
~OpenCLCalcCustomNonbondedForceKernel();
/**
......@@ -365,15 +373,17 @@ public:
*/
double executeEnergy(ContextImpl& context);
private:
bool hasCreatedKernels;
bool hasInitializedKernel;
OpenCLContext& cl;
OpenCLArray<mm_float4>* params;
OpenCLArray<cl_float>* globals;
OpenCLArray<mm_float4>* exceptionParams;
OpenCLArray<mm_int4>* exceptionIndices;
OpenCLArray<mm_float4>* tabulatedFunctionParams;
cl::Kernel exceptionsKernel;
std::vector<std::string> globalParamNames;
std::vector<cl_float> globalParamValues;
std::vector<OpenCLArray<mm_float4>*> tabulatedFunctions;
System& system;
};
......
......@@ -2,13 +2,9 @@
* Compute custom nonbonded exceptions.
*/
__kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions, float cutoffSquared, float4 periodicBoxSize, __global float4* forceBuffers, __global float* energyBuffer,
__kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions, __global float4* forceBuffers, __global float* energyBuffer,
__global float4* posq, __global float4* params, __global int4* indices
#ifdef HAS_GLOBALS
, __constant float* globals) {
#else
) {
#endif
EXTRA_ARGUMENTS) {
int index = get_global_id(0);
float energy = 0.0f;
while (index < numExceptions) {
......@@ -18,15 +14,15 @@ __kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions,
float4 exceptionParams = params[index];
float4 delta = posq[atoms.y]-posq[atoms.x];
#ifdef USE_PERIODIC
delta.x -= floor(delta.x/periodicBoxSize.x+0.5f)*periodicBoxSize.x;
delta.y -= floor(delta.y/periodicBoxSize.y+0.5f)*periodicBoxSize.y;
delta.z -= floor(delta.z/periodicBoxSize.z+0.5f)*periodicBoxSize.z;
delta.x -= floor(delta.x/PERIODIC_BOX_SIZE_X+0.5f)*PERIODIC_BOX_SIZE_X;
delta.y -= floor(delta.y/PERIODIC_BOX_SIZE_Y+0.5f)*PERIODIC_BOX_SIZE_Y;
delta.z -= floor(delta.z/PERIODIC_BOX_SIZE_Z+0.5f)*PERIODIC_BOX_SIZE_Z;
#endif
// Compute the force.
float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
if (r2 > cutoffSquared) {
if (r2 > CUTOFF_SQUARED) {
#else
{
#endif
......
......@@ -241,8 +241,8 @@ int main() {
testExceptions();
testCutoff();
testPeriodic();
// testTabulatedFunction(true);
// testTabulatedFunction(false);
testTabulatedFunction(true);
testTabulatedFunction(false);
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment