Commit b34864fb authored by Peter Eastman's avatar Peter Eastman
Browse files

Implemented ParsedExpression::optimize(). Also added lots of comments.

parent 32e98eb2
......@@ -57,7 +57,7 @@ SUBDIRS (tests examples)
# The source is organized into subdirectories, but we handle them all from
# this CMakeLists file rather than letting CMake visit them as SUBDIRS.
SET(OPENMM_SOURCE_SUBDIRS . openmmapi olla libraries/jama libraries/quern platforms/reference)
SET(OPENMM_SOURCE_SUBDIRS . openmmapi olla libraries/jama libraries/quern libraries/lepton platforms/reference)
# The build system will set ARCH64 for 64 bit builds, which require
# use of the lib64/ library directories rather than lib/.
......
......@@ -40,17 +40,58 @@ namespace Lepton {
class Operation;
/**
* This class represents a node in the abstract syntax tree representation of an expression.
* Each node is defined by an Operation and a set of children. When the expression is
* evaluated, each child is first evaluated in order, then the resulting values are passed
* as the arguments to the Operation's evaluate() method.
*/
class LEPTON_EXPORT ExpressionTreeNode {
public:
/**
* Create a new ExpressionTreeNode.
*
* @param operation the operation for this node. The ExpressionTreeNode takes over ownership
* of this object, and deletes it when the node is itself deleted.
* @param children the children of this node
*/
ExpressionTreeNode(Operation* operation, const std::vector<ExpressionTreeNode>& children);
/**
* Create a new ExpressionTreeNode with two children.
*
* @param operation the operation for this node. The ExpressionTreeNode takes over ownership
* of this object, and deletes it when the node is itself deleted.
* @param child1 the first child of this node
* @param child2 the second child of this node
*/
ExpressionTreeNode(Operation* operation, const ExpressionTreeNode& child1, const ExpressionTreeNode& child2);
/**
* Create a new ExpressionTreeNode with one child.
*
* @param operation the operation for this node. The ExpressionTreeNode takes over ownership
* of this object, and deletes it when the node is itself deleted.
* @param child the child of this node
*/
ExpressionTreeNode(Operation* operation, const ExpressionTreeNode& child);
/**
* Create a new ExpressionTreeNode with no children.
*
* @param operation the operation for this node. The ExpressionTreeNode takes over ownership
* of this object, and deletes it when the node is itself deleted.
*/
ExpressionTreeNode(Operation* operation);
ExpressionTreeNode(const ExpressionTreeNode& node);
ExpressionTreeNode();
~ExpressionTreeNode();
ExpressionTreeNode& operator=(const ExpressionTreeNode& node);
/**
* Get the Operation performed by this node.
*/
const Operation& getOperation() const;
/**
* Get this node's child nodes.
*/
const std::vector<ExpressionTreeNode>& getChildren() const;
private:
Operation* operation;
......
......@@ -41,14 +41,45 @@
namespace Lepton {
/**
* An Operation represents a single step in the evaluation of an expression, such as a function,
* an operator, or a constant value. Each Operation takes some number of values as arguments
* and produces a single value.
*
* This is an abstract class with subclasses for specific operations.
*/
class LEPTON_EXPORT Operation {
public:
/**
* This enumeration lists all Operation subclasses. This is provided so that switch statements
* can be used when processing or analyzing parsed expressions.
*/
enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG,
SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN};
SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SQUARE, CUBE, RECIPROCAL};
/**
* Get the name of this Operation.
*/
virtual std::string getName() const = 0;
/**
* Get this Operation's ID.
*/
virtual Id getId() const = 0;
/**
* Get the number of arguments this operation expects.
*/
virtual int getNumArguments() const = 0;
/**
* Create a clone of this Operation.
*/
virtual Operation* clone() const = 0;
/**
* Perform the computation represented by this operation.
*
* @param args the array of arguments
* @param variables a map containing the values of all variables
* @return the result of performing the computation.
*/
virtual double evaluate(double* args, const std::map<std::string, double>& variables) const = 0;
class Constant;
class Variable;
......@@ -71,6 +102,9 @@ public:
class Asin;
class Acos;
class Atan;
class Square;
class Cube;
class Reciprocal;
};
class Operation::Constant : public Operation {
......@@ -94,6 +128,9 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return value;
}
double getValue() const {
return value;
}
private:
double value;
};
......@@ -526,6 +563,69 @@ public:
}
};
class Operation::Square : public Operation {
public:
Square() {
}
std::string getName() const {
return "square";
}
Id getId() const {
return SQUARE;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Square();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return args[0]*args[0];
}
};
class Operation::Cube : public Operation {
public:
Cube() {
}
std::string getName() const {
return "cube";
}
Id getId() const {
return CUBE;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Cube();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return args[0]*args[0]*args[0];
}
};
class Operation::Reciprocal : public Operation {
public:
Reciprocal() {
}
std::string getName() const {
return "recip";
}
Id getId() const {
return RECIPROCAL;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Reciprocal();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return 1.0/args[0];
}
};
} // namespace Lepton
#endif /*LEPTON_OPERATION_H_*/
......@@ -39,14 +39,51 @@
namespace Lepton {
/**
* This class represents the result of parsing an expression. It provides methods for working with the
* expression in various ways, such as evaluating it, getting the tree representation of the expresson, etc.
*/
class LEPTON_EXPORT ParsedExpression {
public:
/**
* Create a ParsedExpression. Normally you will not call this directly. Instead, use the Parser class
* to parse expression.
*/
ParsedExpression(ExpressionTreeNode rootNode);
/**
* Get the root node of the expression's abstract syntax tree.
*/
const ExpressionTreeNode& getRootNode() const;
/**
* Evaluate the expression. If the expression involves any variables, this method will throw an exception.
*/
double evaluate() const;
/**
* Evaluate the expression.
*
* @param variables a map specifying the values of all variables that appear in the expression. If any
* variable appears in the expression but is not included in this map, an exception
* will be thrown.
*/
double evaluate(const std::map<std::string, double>& variables) const;
/**
* Create a new ParsedExpression which produces the same result as this one, but is faster to evaluate.
*/
ParsedExpression optimize() const;
/**
* Create a new ParsedExpression which produces the same result as this one, but is faster to evaluate.
*
* @param variables a map specifying values for a subset of variables that appear in the expression.
* All occurrences of these variables in the expression are replaced with the values
* specified.
*/
ParsedExpression optimize(const std::map<std::string, double>& variables) const;
private:
double evaluate(const ExpressionTreeNode& node, const std::map<std::string, double>& variables) const;
static double evaluate(const ExpressionTreeNode& node, const std::map<std::string, double>& variables);
static ExpressionTreeNode preevaluateVariables(const ExpressionTreeNode& node, const std::map<std::string, double>& variables);
static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node);
static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node);
ExpressionTreeNode rootNode;
};
......
......@@ -59,7 +59,7 @@ ExpressionTreeNode::ExpressionTreeNode(Operation* operation) : operation(operati
throw Exception("Parse error: wrong number of arguments to function");
}
ExpressionTreeNode::ExpressionTreeNode(const ExpressionTreeNode& node) : operation(node.getOperation().clone()), children(node.getChildren()) {
ExpressionTreeNode::ExpressionTreeNode(const ExpressionTreeNode& node) : operation(&node.getOperation() == NULL ? NULL : node.getOperation().clone()), children(node.getChildren()) {
}
ExpressionTreeNode::ExpressionTreeNode() : operation(NULL) {
......
......@@ -44,16 +44,86 @@ const ExpressionTreeNode& ParsedExpression::getRootNode() const {
}
double ParsedExpression::evaluate() const {
return evaluate(rootNode, map<string, double>());
return evaluate(getRootNode(), map<string, double>());
}
double ParsedExpression::evaluate(const std::map<std::string, double>& variables) const {
return evaluate(rootNode, variables);
return evaluate(getRootNode(), variables);
}
double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map<string, double>& variables) const {
double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map<string, double>& variables) {
vector<double> args(node.getChildren().size());
for (int i = 0; i < args.size(); i++)
args[i] = evaluate(node.getChildren()[i], variables);
return node.getOperation().evaluate(&args[0], variables);
}
ParsedExpression ParsedExpression::optimize() const {
ParsedExpression result = precalculateConstantSubexpressions(getRootNode());
result = substituteSimplerExpression(result.getRootNode());
return result;
}
ParsedExpression ParsedExpression::optimize(const map<string, double>& variables) const {
ParsedExpression result = preevaluateVariables(getRootNode(), variables);
result = precalculateConstantSubexpressions(result.getRootNode());
result = substituteSimplerExpression(result.getRootNode());
return result;
}
ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNode& node, const map<string, double>& variables) {
if (node.getOperation().getId() == Operation::VARIABLE) {
const Operation::Variable& var = dynamic_cast<const Operation::Variable&>(node.getOperation());
map<string, double>::const_iterator iter = variables.find(var.getName());
if (iter == variables.end())
return node;
return ExpressionTreeNode(new Operation::Constant(iter->second));
}
vector<ExpressionTreeNode> children(node.getChildren().size());
for (int i = 0; i < children.size(); i++)
children[i] = preevaluateVariables(node.getChildren()[i], variables);
return ExpressionTreeNode(node.getOperation().clone(), children);
}
ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node) {
vector<ExpressionTreeNode> children(node.getChildren().size());
for (int i = 0; i < children.size(); i++)
children[i] = precalculateConstantSubexpressions(node.getChildren()[i]);
ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children);
if (node.getOperation().getId() == Operation::VARIABLE)
return result;
for (int i = 0; i < children.size(); i++)
if (children[i].getOperation().getId() != Operation::CONSTANT)
return result;
return ExpressionTreeNode(new Operation::Constant(evaluate(result, map<string, double>())));
}
ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node) {
vector<ExpressionTreeNode> children(node.getChildren().size());
for (int i = 0; i < children.size(); i++)
children[i] = substituteSimplerExpression(node.getChildren()[i]);
switch (node.getOperation().getId()) {
case Operation::DIVIDE:
if (children[0].getOperation().getId() == Operation::CONSTANT) {
if (dynamic_cast<const Operation::Constant&>(children[0].getOperation()).getValue() == 1.0)
return ExpressionTreeNode(new Operation::Reciprocal(), children[1]);
}
break;
case Operation::POWER:
if (children[1].getOperation().getId() == Operation::CONSTANT) {
double exponent = dynamic_cast<const Operation::Constant&>(children[1].getOperation()).getValue();
if (exponent == 1.0)
return children[0];
if (exponent == -1.0)
return ExpressionTreeNode(new Operation::Reciprocal(), children[0]);
if (exponent == 2.0)
return ExpressionTreeNode(new Operation::Square(), children[0]);
if (exponent == 3.0)
return ExpressionTreeNode(new Operation::Cube(), children[0]);
if (exponent == 0.5)
return ExpressionTreeNode(new Operation::Sqrt(), children[0]);
}
break;
}
return ExpressionTreeNode(node.getOperation().clone(), children);
}
......@@ -236,6 +236,9 @@ Operation* Parser::getFunctionOperation(const std::string& name, int arguments)
opMap["asin"] = Operation::ASIN;
opMap["acos"] = Operation::ACOS;
opMap["atan"] = Operation::ATAN;
opMap["square"] = Operation::SQUARE;
opMap["cube"] = Operation::CUBE;
opMap["recip"] = Operation::RECIPROCAL;
}
string trimmed = name.substr(0, name.size()-1);
map<string, Operation::Id>::const_iterator iter = opMap.find(trimmed);
......@@ -266,6 +269,12 @@ Operation* Parser::getFunctionOperation(const std::string& name, int arguments)
return new Operation::Acos();
case Operation::ATAN:
return new Operation::Atan();
case Operation::SQUARE:
return new Operation::Square();
case Operation::CUBE:
return new Operation::Cube();
case Operation::RECIPROCAL:
return new Operation::Reciprocal();
default:
throw Exception("Parse error: unknown function");
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment