"vscode:/vscode.git/clone" did not exist on "5c2569b0b55bec609c7883e2e694c1d2de9dfb22"
Commit 119fb95d authored by Peter Eastman's avatar Peter Eastman
Browse files

Implemented custom function

parent b3f4c0f8
#ifndef LEPTON_CUSTOM_FUNCTION_H_
#define LEPTON_CUSTOM_FUNCTION_H_
/* -------------------------------------------------------------------------- *
* Lepton *
* -------------------------------------------------------------------------- *
* This is part of the Lepton expression parser originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "windowsIncludes.h"
namespace Lepton {
/**
* This class is the interface for defining your own function that may be included in expressions.
* To use it, create a concrete subclass that implements all of the virtual methods for each new function
* you want to define. Then when you call Parser::parse() to parse an expression, pass a map of
* function names to CustomFunction objects.
*/
class LEPTON_EXPORT CustomFunction {
public:
virtual ~CustomFunction() {
}
/**
* Get the number of arguments this function exprects.
*/
virtual int getNumArguments() const = 0;
/**
* Evaluate the function.
*
* @param arguments the array of argument values
*/
virtual double evaluate(const double* arguments) const = 0;
/**
* Evaluate a derivative of the function.
*
* @param arguments the array of argument values
* @param derivOrder an array specifying the number of times the function has been differentiated
* with respect to each of its arguments. For example, the array {0, 2} indicates
* a second derivative with respect to the second argument.
*/
virtual double evaluateDerivative(const double* arguments, const int* derivOrder) const = 0;
/**
* Create a new duplicate of this object on the heap using the "new" operator.
*/
virtual CustomFunction* clone() const = 0;
};
} // namespace Lepton
#endif /*LEPTON_CUSTOM_FUNCTION_H_*/
......@@ -33,6 +33,7 @@
* -------------------------------------------------------------------------- */
#include "windowsIncludes.h"
#include "CustomFunction.h"
#include <cmath>
#include <map>
#include <string>
......@@ -53,6 +54,8 @@ class ExpressionTreeNode;
class LEPTON_EXPORT Operation {
public:
virtual ~Operation() {
}
/**
* This enumeration lists all Operation subclasses. This is provided so that switch statements
* can be used when processing or analyzing parsed expressions.
......@@ -177,7 +180,13 @@ private:
class Operation::Custom : public Operation {
public:
Custom(const std::string& name, int arguments) : name(name), arguments(arguments) {
Custom(const std::string& name, CustomFunction* function) : name(name), function(function), isDerivative(false), derivOrder(function->getNumArguments(), 0) {
}
Custom(const Custom& base, int derivIndex) : name(base.name), function(base.function->clone()), isDerivative(true), derivOrder(base.derivOrder) {
derivOrder[derivIndex]++;
}
~Custom() {
delete function;
}
std::string getName() const {
return name;
......@@ -186,18 +195,25 @@ public:
return CUSTOM;
}
int getNumArguments() const {
return arguments;
return function->getNumArguments();
}
Operation* clone() const {
return new Custom(name, arguments);
Custom* clone = new Custom(name, function->clone());
clone->isDerivative = isDerivative;
clone->derivOrder = derivOrder;
return clone;
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return 0.0;
if (isDerivative)
return function->evaluateDerivative(args, &derivOrder[0]);
return function->evaluate(args);
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
private:
std::string name;
int arguments;
CustomFunction* function;
bool isDerivative;
std::vector<int> derivOrder;
};
class Operation::Add : public Operation {
......
......@@ -33,11 +33,13 @@
* -------------------------------------------------------------------------- */
#include "windowsIncludes.h"
#include <map>
#include <string>
#include <vector>
namespace Lepton {
class CustomFunction;
class ExpressionTreeNode;
class Operation;
class ParsedExpression;
......@@ -49,13 +51,23 @@ class ParseToken;
class LEPTON_EXPORT Parser {
public:
static ParsedExpression parse(std::string expression);
/**
* Parse a mathematical expression and return a representation of it as an abstract syntax tree.
*/
static ParsedExpression parse(const std::string& expression);
/**
* Parse a mathematical expression and return a representation of it as an abstract syntax tree.
*
* @param customFunctions a map specifying user defined functions that may appear in the expression.
* The key are function names, and the values are corresponding CustomFunction objects.
*/
static ParsedExpression parse(const std::string& expression, const std::map<std::string, CustomFunction*>& customFunctions);
private:
static std::vector<ParseToken> tokenize(std::string expression);
static ParseToken getNextToken(std::string expression, int start);
static ExpressionTreeNode parsePrecedence(const std::vector<ParseToken>& tokens, int& pos, int precedence);
static ExpressionTreeNode parsePrecedence(const std::vector<ParseToken>& tokens, int& pos, const std::map<std::string, CustomFunction*>& customFunctions, int precedence);
static Operation* getOperatorOperation(const std::string& name);
static Operation* getFunctionOperation(const std::string& name, int arguments);
static Operation* getFunctionOperation(const std::string& name, const std::map<std::string, CustomFunction*>& customFunctions);
};
} // namespace Lepton
......
......@@ -47,7 +47,15 @@ ExpressionTreeNode Operation::Variable::differentiate(const std::vector<Expressi
}
ExpressionTreeNode Operation::Custom::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (function->getNumArguments() == 0)
return ExpressionTreeNode(new Operation::Constant(0.0));
ExpressionTreeNode result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, 0), children), childDerivs[0]);
for (int i = 1; i < getNumArguments(); i++) {
result = ExpressionTreeNode(new Operation::Add(),
result,
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i]));
}
return result;
}
ExpressionTreeNode Operation::Add::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
......
......@@ -30,6 +30,7 @@
* -------------------------------------------------------------------------- */
#include "Parser.h"
#include "CustomFunction.h"
#include "Exception.h"
#include "ExpressionTreeNode.h"
#include "Operation.h"
......@@ -130,16 +131,20 @@ vector<ParseToken> Parser::tokenize(string expression) {
return tokens;
}
ParsedExpression Parser::parse(string expression) {
ParsedExpression Parser::parse(const string& expression) {
return parse(expression, map<string, CustomFunction*>());
}
ParsedExpression Parser::parse(const string& expression, const map<string, CustomFunction*>& customFunctions) {
vector<ParseToken> tokens = tokenize(expression);
int pos = 0;
ExpressionTreeNode result = parsePrecedence(tokens, pos, 0);
ExpressionTreeNode result = parsePrecedence(tokens, pos, customFunctions, 0);
if (pos != tokens.size())
throw Exception("Parse error: unexpected text at end of expression");
return ParsedExpression(result);
}
ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int& pos, int precedence) {
ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int& pos, const map<string, CustomFunction*>& customFunctions, int precedence) {
if (pos == tokens.size())
throw Exception("Parse error: unexpected end of expression");
......@@ -160,7 +165,7 @@ ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int
}
else if (token.getType() == ParseToken::LeftParen) {
pos++;
result = parsePrecedence(tokens, pos, 0);
result = parsePrecedence(tokens, pos, customFunctions, 0);
if (pos == tokens.size() || tokens[pos].getType() != ParseToken::RightParen)
throw Exception("Parse error: unbalanced parentheses");
pos++;
......@@ -170,7 +175,7 @@ ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int
vector<ExpressionTreeNode> args;
bool moreArgs;
do {
args.push_back(parsePrecedence(tokens, pos, 0));
args.push_back(parsePrecedence(tokens, pos, customFunctions, 0));
moreArgs = (pos < tokens.size() && tokens[pos].getType() == ParseToken::Comma);
if (moreArgs)
pos++;
......@@ -178,11 +183,11 @@ ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int
if (pos == tokens.size() || tokens[pos].getType() != ParseToken::RightParen)
throw Exception("Parse error: unbalanced parentheses");
pos++;
result = ExpressionTreeNode(getFunctionOperation(token.getText(), args.size()), args);
result = ExpressionTreeNode(getFunctionOperation(token.getText(), customFunctions), args);
}
else if (token.getType() == ParseToken::Operator && token.getText() == "-") {
pos++;
ExpressionTreeNode toNegate = parsePrecedence(tokens, pos, 2);
ExpressionTreeNode toNegate = parsePrecedence(tokens, pos, customFunctions, 2);
result = ExpressionTreeNode(new Operation::Negate(), toNegate);
}
else
......@@ -197,7 +202,7 @@ ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int
if (opPrecedence < precedence)
return result;
pos++;
ExpressionTreeNode arg = parsePrecedence(tokens, pos, LeftAssociative[op] ? opPrecedence+1 : opPrecedence);
ExpressionTreeNode arg = parsePrecedence(tokens, pos, customFunctions, LeftAssociative[op] ? opPrecedence+1 : opPrecedence);
result = ExpressionTreeNode(getOperatorOperation(token.getText()), result, arg);
}
return result;
......@@ -220,7 +225,7 @@ Operation* Parser::getOperatorOperation(const std::string& name) {
}
}
Operation* Parser::getFunctionOperation(const std::string& name, int arguments) {
Operation* Parser::getFunctionOperation(const std::string& name, const map<string, CustomFunction*>& customFunctions) {
static map<string, Operation::Id> opMap;
if (opMap.size() == 0) {
......@@ -243,9 +248,18 @@ Operation* Parser::getFunctionOperation(const std::string& name, int arguments)
opMap["decrement"] = Operation::DECREMENT;
}
string trimmed = name.substr(0, name.size()-1);
// First check custom functions.
map<string, CustomFunction*>::const_iterator custom = customFunctions.find(trimmed);
if (custom != customFunctions.end())
return new Operation::Custom(trimmed, custom->second->clone());
// Now try standard functions.
map<string, Operation::Id>::const_iterator iter = opMap.find(trimmed);
if (iter == opMap.end())
return new Operation::Custom(trimmed, arguments);
throw Exception("Parse error: unknown function");
switch (iter->second) {
case Operation::SQRT:
return new Operation::Sqrt();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment