"ssh:/git@developer.sourcefind.cn:2222/tsoc/openmm.git" did not exist on "4cfeb429535a521a276ec213b9d3412ea6075067"
Commit 119fb95d authored by Peter Eastman's avatar Peter Eastman
Browse files

Implemented custom function

parent b3f4c0f8
#ifndef LEPTON_CUSTOM_FUNCTION_H_
#define LEPTON_CUSTOM_FUNCTION_H_
/* -------------------------------------------------------------------------- *
* Lepton *
* -------------------------------------------------------------------------- *
* This is part of the Lepton expression parser originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "windowsIncludes.h"
namespace Lepton {
/**
* This class is the interface for defining your own function that may be included in expressions.
* To use it, create a concrete subclass that implements all of the virtual methods for each new function
* you want to define. Then when you call Parser::parse() to parse an expression, pass a map of
* function names to CustomFunction objects.
*/
class LEPTON_EXPORT CustomFunction {
public:
virtual ~CustomFunction() {
}
/**
* Get the number of arguments this function exprects.
*/
virtual int getNumArguments() const = 0;
/**
* Evaluate the function.
*
* @param arguments the array of argument values
*/
virtual double evaluate(const double* arguments) const = 0;
/**
* Evaluate a derivative of the function.
*
* @param arguments the array of argument values
* @param derivOrder an array specifying the number of times the function has been differentiated
* with respect to each of its arguments. For example, the array {0, 2} indicates
* a second derivative with respect to the second argument.
*/
virtual double evaluateDerivative(const double* arguments, const int* derivOrder) const = 0;
/**
* Create a new duplicate of this object on the heap using the "new" operator.
*/
virtual CustomFunction* clone() const = 0;
};
} // namespace Lepton
#endif /*LEPTON_CUSTOM_FUNCTION_H_*/
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "windowsIncludes.h" #include "windowsIncludes.h"
#include "CustomFunction.h"
#include <cmath> #include <cmath>
#include <map> #include <map>
#include <string> #include <string>
...@@ -53,6 +54,8 @@ class ExpressionTreeNode; ...@@ -53,6 +54,8 @@ class ExpressionTreeNode;
class LEPTON_EXPORT Operation { class LEPTON_EXPORT Operation {
public: public:
virtual ~Operation() {
}
/** /**
* This enumeration lists all Operation subclasses. This is provided so that switch statements * This enumeration lists all Operation subclasses. This is provided so that switch statements
* can be used when processing or analyzing parsed expressions. * can be used when processing or analyzing parsed expressions.
...@@ -177,7 +180,13 @@ private: ...@@ -177,7 +180,13 @@ private:
class Operation::Custom : public Operation { class Operation::Custom : public Operation {
public: public:
Custom(const std::string& name, int arguments) : name(name), arguments(arguments) { Custom(const std::string& name, CustomFunction* function) : name(name), function(function), isDerivative(false), derivOrder(function->getNumArguments(), 0) {
}
Custom(const Custom& base, int derivIndex) : name(base.name), function(base.function->clone()), isDerivative(true), derivOrder(base.derivOrder) {
derivOrder[derivIndex]++;
}
~Custom() {
delete function;
} }
std::string getName() const { std::string getName() const {
return name; return name;
...@@ -186,18 +195,25 @@ public: ...@@ -186,18 +195,25 @@ public:
return CUSTOM; return CUSTOM;
} }
int getNumArguments() const { int getNumArguments() const {
return arguments; return function->getNumArguments();
} }
Operation* clone() const { Operation* clone() const {
return new Custom(name, arguments); Custom* clone = new Custom(name, function->clone());
clone->isDerivative = isDerivative;
clone->derivOrder = derivOrder;
return clone;
} }
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return 0.0; if (isDerivative)
return function->evaluateDerivative(args, &derivOrder[0]);
return function->evaluate(args);
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const; ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
private: private:
std::string name; std::string name;
int arguments; CustomFunction* function;
bool isDerivative;
std::vector<int> derivOrder;
}; };
class Operation::Add : public Operation { class Operation::Add : public Operation {
......
...@@ -33,11 +33,13 @@ ...@@ -33,11 +33,13 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "windowsIncludes.h" #include "windowsIncludes.h"
#include <map>
#include <string> #include <string>
#include <vector> #include <vector>
namespace Lepton { namespace Lepton {
class CustomFunction;
class ExpressionTreeNode; class ExpressionTreeNode;
class Operation; class Operation;
class ParsedExpression; class ParsedExpression;
...@@ -49,13 +51,23 @@ class ParseToken; ...@@ -49,13 +51,23 @@ class ParseToken;
class LEPTON_EXPORT Parser { class LEPTON_EXPORT Parser {
public: public:
static ParsedExpression parse(std::string expression); /**
* Parse a mathematical expression and return a representation of it as an abstract syntax tree.
*/
static ParsedExpression parse(const std::string& expression);
/**
* Parse a mathematical expression and return a representation of it as an abstract syntax tree.
*
* @param customFunctions a map specifying user defined functions that may appear in the expression.
* The key are function names, and the values are corresponding CustomFunction objects.
*/
static ParsedExpression parse(const std::string& expression, const std::map<std::string, CustomFunction*>& customFunctions);
private: private:
static std::vector<ParseToken> tokenize(std::string expression); static std::vector<ParseToken> tokenize(std::string expression);
static ParseToken getNextToken(std::string expression, int start); static ParseToken getNextToken(std::string expression, int start);
static ExpressionTreeNode parsePrecedence(const std::vector<ParseToken>& tokens, int& pos, int precedence); static ExpressionTreeNode parsePrecedence(const std::vector<ParseToken>& tokens, int& pos, const std::map<std::string, CustomFunction*>& customFunctions, int precedence);
static Operation* getOperatorOperation(const std::string& name); static Operation* getOperatorOperation(const std::string& name);
static Operation* getFunctionOperation(const std::string& name, int arguments); static Operation* getFunctionOperation(const std::string& name, const std::map<std::string, CustomFunction*>& customFunctions);
}; };
} // namespace Lepton } // namespace Lepton
......
...@@ -47,7 +47,15 @@ ExpressionTreeNode Operation::Variable::differentiate(const std::vector<Expressi ...@@ -47,7 +47,15 @@ ExpressionTreeNode Operation::Variable::differentiate(const std::vector<Expressi
} }
ExpressionTreeNode Operation::Custom::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Custom::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Constant(0.0)); if (function->getNumArguments() == 0)
return ExpressionTreeNode(new Operation::Constant(0.0));
ExpressionTreeNode result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, 0), children), childDerivs[0]);
for (int i = 1; i < getNumArguments(); i++) {
result = ExpressionTreeNode(new Operation::Add(),
result,
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i]));
}
return result;
} }
ExpressionTreeNode Operation::Add::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Add::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "Parser.h" #include "Parser.h"
#include "CustomFunction.h"
#include "Exception.h" #include "Exception.h"
#include "ExpressionTreeNode.h" #include "ExpressionTreeNode.h"
#include "Operation.h" #include "Operation.h"
...@@ -130,16 +131,20 @@ vector<ParseToken> Parser::tokenize(string expression) { ...@@ -130,16 +131,20 @@ vector<ParseToken> Parser::tokenize(string expression) {
return tokens; return tokens;
} }
ParsedExpression Parser::parse(string expression) { ParsedExpression Parser::parse(const string& expression) {
return parse(expression, map<string, CustomFunction*>());
}
ParsedExpression Parser::parse(const string& expression, const map<string, CustomFunction*>& customFunctions) {
vector<ParseToken> tokens = tokenize(expression); vector<ParseToken> tokens = tokenize(expression);
int pos = 0; int pos = 0;
ExpressionTreeNode result = parsePrecedence(tokens, pos, 0); ExpressionTreeNode result = parsePrecedence(tokens, pos, customFunctions, 0);
if (pos != tokens.size()) if (pos != tokens.size())
throw Exception("Parse error: unexpected text at end of expression"); throw Exception("Parse error: unexpected text at end of expression");
return ParsedExpression(result); return ParsedExpression(result);
} }
ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int& pos, int precedence) { ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int& pos, const map<string, CustomFunction*>& customFunctions, int precedence) {
if (pos == tokens.size()) if (pos == tokens.size())
throw Exception("Parse error: unexpected end of expression"); throw Exception("Parse error: unexpected end of expression");
...@@ -160,7 +165,7 @@ ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int ...@@ -160,7 +165,7 @@ ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int
} }
else if (token.getType() == ParseToken::LeftParen) { else if (token.getType() == ParseToken::LeftParen) {
pos++; pos++;
result = parsePrecedence(tokens, pos, 0); result = parsePrecedence(tokens, pos, customFunctions, 0);
if (pos == tokens.size() || tokens[pos].getType() != ParseToken::RightParen) if (pos == tokens.size() || tokens[pos].getType() != ParseToken::RightParen)
throw Exception("Parse error: unbalanced parentheses"); throw Exception("Parse error: unbalanced parentheses");
pos++; pos++;
...@@ -170,7 +175,7 @@ ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int ...@@ -170,7 +175,7 @@ ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int
vector<ExpressionTreeNode> args; vector<ExpressionTreeNode> args;
bool moreArgs; bool moreArgs;
do { do {
args.push_back(parsePrecedence(tokens, pos, 0)); args.push_back(parsePrecedence(tokens, pos, customFunctions, 0));
moreArgs = (pos < tokens.size() && tokens[pos].getType() == ParseToken::Comma); moreArgs = (pos < tokens.size() && tokens[pos].getType() == ParseToken::Comma);
if (moreArgs) if (moreArgs)
pos++; pos++;
...@@ -178,11 +183,11 @@ ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int ...@@ -178,11 +183,11 @@ ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int
if (pos == tokens.size() || tokens[pos].getType() != ParseToken::RightParen) if (pos == tokens.size() || tokens[pos].getType() != ParseToken::RightParen)
throw Exception("Parse error: unbalanced parentheses"); throw Exception("Parse error: unbalanced parentheses");
pos++; pos++;
result = ExpressionTreeNode(getFunctionOperation(token.getText(), args.size()), args); result = ExpressionTreeNode(getFunctionOperation(token.getText(), customFunctions), args);
} }
else if (token.getType() == ParseToken::Operator && token.getText() == "-") { else if (token.getType() == ParseToken::Operator && token.getText() == "-") {
pos++; pos++;
ExpressionTreeNode toNegate = parsePrecedence(tokens, pos, 2); ExpressionTreeNode toNegate = parsePrecedence(tokens, pos, customFunctions, 2);
result = ExpressionTreeNode(new Operation::Negate(), toNegate); result = ExpressionTreeNode(new Operation::Negate(), toNegate);
} }
else else
...@@ -197,7 +202,7 @@ ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int ...@@ -197,7 +202,7 @@ ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int
if (opPrecedence < precedence) if (opPrecedence < precedence)
return result; return result;
pos++; pos++;
ExpressionTreeNode arg = parsePrecedence(tokens, pos, LeftAssociative[op] ? opPrecedence+1 : opPrecedence); ExpressionTreeNode arg = parsePrecedence(tokens, pos, customFunctions, LeftAssociative[op] ? opPrecedence+1 : opPrecedence);
result = ExpressionTreeNode(getOperatorOperation(token.getText()), result, arg); result = ExpressionTreeNode(getOperatorOperation(token.getText()), result, arg);
} }
return result; return result;
...@@ -220,7 +225,7 @@ Operation* Parser::getOperatorOperation(const std::string& name) { ...@@ -220,7 +225,7 @@ Operation* Parser::getOperatorOperation(const std::string& name) {
} }
} }
Operation* Parser::getFunctionOperation(const std::string& name, int arguments) { Operation* Parser::getFunctionOperation(const std::string& name, const map<string, CustomFunction*>& customFunctions) {
static map<string, Operation::Id> opMap; static map<string, Operation::Id> opMap;
if (opMap.size() == 0) { if (opMap.size() == 0) {
...@@ -243,9 +248,18 @@ Operation* Parser::getFunctionOperation(const std::string& name, int arguments) ...@@ -243,9 +248,18 @@ Operation* Parser::getFunctionOperation(const std::string& name, int arguments)
opMap["decrement"] = Operation::DECREMENT; opMap["decrement"] = Operation::DECREMENT;
} }
string trimmed = name.substr(0, name.size()-1); string trimmed = name.substr(0, name.size()-1);
// First check custom functions.
map<string, CustomFunction*>::const_iterator custom = customFunctions.find(trimmed);
if (custom != customFunctions.end())
return new Operation::Custom(trimmed, custom->second->clone());
// Now try standard functions.
map<string, Operation::Id>::const_iterator iter = opMap.find(trimmed); map<string, Operation::Id>::const_iterator iter = opMap.find(trimmed);
if (iter == opMap.end()) if (iter == opMap.end())
return new Operation::Custom(trimmed, arguments); throw Exception("Parse error: unknown function");
switch (iter->second) { switch (iter->second) {
case Operation::SQRT: case Operation::SQRT:
return new Operation::Sqrt(); return new Operation::Sqrt();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment