Commit e5c065df authored by Peter Eastman's avatar Peter Eastman
Browse files

Lots of new optimizations in ParsedExpression.optimize(). Also optimized CUDA...

Lots of new optimizations in ParsedExpression.optimize().  Also optimized CUDA evaluation of expressions.
parent 41bc0b2d
...@@ -62,7 +62,7 @@ public: ...@@ -62,7 +62,7 @@ public:
* can be used when processing or analyzing parsed expressions. * can be used when processing or analyzing parsed expressions.
*/ */
enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG, enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG,
SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SQUARE, CUBE, RECIPROCAL, INCREMENT, DECREMENT}; SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SQUARE, CUBE, RECIPROCAL, ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT};
/** /**
* Get the name of this Operation. * Get the name of this Operation.
*/ */
...@@ -119,8 +119,9 @@ public: ...@@ -119,8 +119,9 @@ public:
class Square; class Square;
class Cube; class Cube;
class Reciprocal; class Reciprocal;
class Increment; class AddConstant;
class Decrement; class MultiplyConstant;
class PowerConstant;
}; };
class Operation::Constant : public Operation { class Operation::Constant : public Operation {
...@@ -679,48 +680,91 @@ public: ...@@ -679,48 +680,91 @@ public:
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const; ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Increment : public Operation { class Operation::AddConstant : public Operation {
public: public:
Increment() { AddConstant(double value) : value(value) {
} }
std::string getName() const { std::string getName() const {
return "increment"; std::stringstream name;
name << value << "+";
return name.str();
} }
Id getId() const { Id getId() const {
return INCREMENT; return ADD_CONSTANT;
} }
int getNumArguments() const { int getNumArguments() const {
return 1; return 1;
} }
Operation* clone() const { Operation* clone() const {
return new Increment(); return new AddConstant(value);
} }
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return args[0]+1.0; return args[0]+value;
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const; ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
double getValue() const {
return value;
}
private:
double value;
}; };
class Operation::Decrement : public Operation { class Operation::MultiplyConstant : public Operation {
public: public:
Decrement() { MultiplyConstant(double value) : value(value) {
} }
std::string getName() const { std::string getName() const {
return "decrement"; std::stringstream name;
name << value << "*";
return name.str();
} }
Id getId() const { Id getId() const {
return DECREMENT; return MULTIPLY_CONSTANT;
} }
int getNumArguments() const { int getNumArguments() const {
return 1; return 1;
} }
Operation* clone() const { Operation* clone() const {
return new Decrement(); return new MultiplyConstant(value);
} }
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return args[0]-1.0; return args[0]*value;
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const; ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
double getValue() const {
return value;
}
private:
double value;
};
class Operation::PowerConstant : public Operation {
public:
PowerConstant(double value) : value(value) {
}
std::string getName() const {
std::stringstream name;
name << "^" << value;
return name.str();
}
Id getId() const {
return POWER_CONSTANT;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new PowerConstant(value);
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::pow(args[0], value);
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
double getValue() const {
return value;
}
private:
double value;
}; };
} // namespace Lepton } // namespace Lepton
......
...@@ -86,7 +86,7 @@ ExpressionTreeNode Operation::Power::differentiate(const std::vector<ExpressionT ...@@ -86,7 +86,7 @@ ExpressionTreeNode Operation::Power::differentiate(const std::vector<ExpressionT
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(),
children[1], children[1],
ExpressionTreeNode(new Operation::Power(), ExpressionTreeNode(new Operation::Power(),
children[0], ExpressionTreeNode(new Operation::Decrement(), children[1]))), children[0], ExpressionTreeNode(new Operation::AddConstant(-1.0), children[1]))),
childDerivs[0]), childDerivs[0]),
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(),
...@@ -101,8 +101,7 @@ ExpressionTreeNode Operation::Negate::differentiate(const std::vector<Expression ...@@ -101,8 +101,7 @@ ExpressionTreeNode Operation::Negate::differentiate(const std::vector<Expression
ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(0.5),
ExpressionTreeNode(new Operation::Constant(0.5)),
ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::Reciprocal(),
ExpressionTreeNode(new Operation::Sqrt(), children[0]))), ExpressionTreeNode(new Operation::Sqrt(), children[0]))),
childDerivs[0]); childDerivs[0]);
...@@ -189,23 +188,21 @@ ExpressionTreeNode Operation::Acos::differentiate(const std::vector<ExpressionTr ...@@ -189,23 +188,21 @@ ExpressionTreeNode Operation::Acos::differentiate(const std::vector<ExpressionTr
ExpressionTreeNode Operation::Atan::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Atan::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::Reciprocal(),
ExpressionTreeNode(new Operation::Increment(), ExpressionTreeNode(new Operation::AddConstant(1.0),
ExpressionTreeNode(new Operation::Square(), children[0]))), ExpressionTreeNode(new Operation::Square(), children[0]))),
childDerivs[0]); childDerivs[0]);
} }
ExpressionTreeNode Operation::Square::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Square::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(2.0),
ExpressionTreeNode(new Operation::Constant(2.0)),
children[0]), children[0]),
childDerivs[0]); childDerivs[0]);
} }
ExpressionTreeNode Operation::Cube::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Cube::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(3.0),
ExpressionTreeNode(new Operation::Constant(3.0)),
ExpressionTreeNode(new Operation::Square(), children[0])), ExpressionTreeNode(new Operation::Square(), children[0])),
childDerivs[0]); childDerivs[0]);
} }
...@@ -218,10 +215,19 @@ ExpressionTreeNode Operation::Reciprocal::differentiate(const std::vector<Expres ...@@ -218,10 +215,19 @@ ExpressionTreeNode Operation::Reciprocal::differentiate(const std::vector<Expres
childDerivs[0]); childDerivs[0]);
} }
ExpressionTreeNode Operation::Increment::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::AddConstant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return childDerivs[0]; return childDerivs[0];
} }
ExpressionTreeNode Operation::Decrement::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::MultiplyConstant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return childDerivs[0]; return ExpressionTreeNode(new Operation::MultiplyConstant(value),
childDerivs[0]);
}
ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(value),
ExpressionTreeNode(new Operation::PowerConstant(value-1),
children[0])),
childDerivs[0]));
} }
...@@ -63,13 +63,15 @@ double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map<stri ...@@ -63,13 +63,15 @@ double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map<stri
ParsedExpression ParsedExpression::optimize() const { ParsedExpression ParsedExpression::optimize() const {
ExpressionTreeNode result = precalculateConstantSubexpressions(getRootNode()); ExpressionTreeNode result = precalculateConstantSubexpressions(getRootNode());
result = substituteSimplerExpression(result); result = substituteSimplerExpression(result);
return result; result = substituteSimplerExpression(result);
return ParsedExpression(result);
} }
ParsedExpression ParsedExpression::optimize(const map<string, double>& variables) const { ParsedExpression ParsedExpression::optimize(const map<string, double>& variables) const {
ExpressionTreeNode result = preevaluateVariables(getRootNode(), variables); ExpressionTreeNode result = preevaluateVariables(getRootNode(), variables);
result = precalculateConstantSubexpressions(result); result = precalculateConstantSubexpressions(result);
result = substituteSimplerExpression(result); result = substituteSimplerExpression(result);
result = substituteSimplerExpression(result);
return ParsedExpression(result); return ParsedExpression(result);
} }
...@@ -111,12 +113,12 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -111,12 +113,12 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
double second = getConstantValue(children[1]); double second = getConstantValue(children[1]);
if (first == 0.0) // Add 0 if (first == 0.0) // Add 0
return children[1]; return children[1];
if (first == 1.0) // Add 1
return ExpressionTreeNode(new Operation::Increment(), children[1]);
if (second == 0.0) // Add 0 if (second == 0.0) // Add 0
return children[0]; return children[0];
if (second == 1.0) // Add 1 if (first == first) // Add a constant
return ExpressionTreeNode(new Operation::Increment(), children[0]); return ExpressionTreeNode(new Operation::AddConstant(first), children[1]);
if (second == second) // Add a constant
return ExpressionTreeNode(new Operation::AddConstant(second), children[0]);
break; break;
} }
case Operation::SUBTRACT: case Operation::SUBTRACT:
...@@ -127,8 +129,8 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -127,8 +129,8 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
double second = getConstantValue(children[1]); double second = getConstantValue(children[1]);
if (second == 0.0) // Subtract 0 if (second == 0.0) // Subtract 0
return children[0]; return children[0];
if (second == 1.0) // Subtract 1 if (second == second) // Subtract a constant
return ExpressionTreeNode(new Operation::Decrement(), children[0]); return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]);
break; break;
} }
case Operation::MULTIPLY: case Operation::MULTIPLY:
...@@ -141,21 +143,15 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -141,21 +143,15 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
return children[1]; return children[1];
if (second == 1.0) // Multiply by 1 if (second == 1.0) // Multiply by 1
return children[0]; return children[0];
if (children[0].getOperation().getId() == Operation::CONSTANT) { if (children[0].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant
if (children[1].getOperation().getId() == Operation::MULTIPLY) { if (children[1].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
if (children[1].getChildren()[0].getOperation().getId() == Operation::CONSTANT) // Combine two multiplies into a single one return ExpressionTreeNode(new Operation::MultiplyConstant(first*dynamic_cast<const Operation::MultiplyConstant*>(&children[1].getOperation())->getValue()), children[1].getChildren()[0]);
return ExpressionTreeNode(new Operation::Multiply(), children[1].getChildren()[1], ExpressionTreeNode(new Operation::Constant(getConstantValue(children[1].getChildren()[0])*first))); return ExpressionTreeNode(new Operation::MultiplyConstant(first), children[1]);
if (children[1].getChildren()[1].getOperation().getId() == Operation::CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::Multiply(), children[1].getChildren()[0], ExpressionTreeNode(new Operation::Constant(getConstantValue(children[1].getChildren()[1])*first)));
}
}
if (children[1].getOperation().getId() == Operation::CONSTANT) {
if (children[0].getOperation().getId() == Operation::MULTIPLY) {
if (children[0].getChildren()[0].getOperation().getId() == Operation::CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::Multiply(), children[0].getChildren()[1], ExpressionTreeNode(new Operation::Constant(getConstantValue(children[0].getChildren()[0])*second)));
if (children[0].getChildren()[1].getOperation().getId() == Operation::CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::Multiply(), children[0].getChildren()[0], ExpressionTreeNode(new Operation::Constant(getConstantValue(children[0].getChildren()[1])*second)));
} }
if (children[1].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(second*dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(second), children[0]);
} }
break; break;
} }
...@@ -170,13 +166,9 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -170,13 +166,9 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
if (denominator == 1.0) // Divide by 1 if (denominator == 1.0) // Divide by 1
return children[0]; return children[0];
if (children[1].getOperation().getId() == Operation::CONSTANT) { if (children[1].getOperation().getId() == Operation::CONSTANT) {
if (children[0].getOperation().getId() == Operation::MULTIPLY) { if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a divide into one multiply
if (children[0].getChildren()[0].getOperation().getId() == Operation::CONSTANT) // Combine a multiply and a divide into one multiply return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()/denominator), children[0].getChildren()[0]);
return ExpressionTreeNode(new Operation::Multiply(), children[0].getChildren()[1], ExpressionTreeNode(new Operation::Constant(getConstantValue(children[0].getChildren()[0])/denominator))); return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/denominator), children[0]); // Replace a divide with a multiply
if (children[0].getChildren()[1].getOperation().getId() == Operation::CONSTANT) // Combine a multiply and a divide into one multiply
return ExpressionTreeNode(new Operation::Multiply(), children[0].getChildren()[0], ExpressionTreeNode(new Operation::Constant(getConstantValue(children[0].getChildren()[1])/denominator)));
}
return ExpressionTreeNode(new Operation::Multiply(), children[0], ExpressionTreeNode(new Operation::Constant(1.0/denominator))); // Replace a divide with a multiply
} }
break; break;
} }
...@@ -200,6 +192,14 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -200,6 +192,14 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
return ExpressionTreeNode(new Operation::Cube(), children[0]); return ExpressionTreeNode(new Operation::Cube(), children[0]);
if (exponent == 0.5) // x^0.5 = sqrt(x) if (exponent == 0.5) // x^0.5 = sqrt(x)
return ExpressionTreeNode(new Operation::Sqrt(), children[0]); return ExpressionTreeNode(new Operation::Sqrt(), children[0]);
if (exponent == exponent) // Constant power
return ExpressionTreeNode(new Operation::PowerConstant(exponent), children[0]);
break;
}
case Operation::NEGATE:
{
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a negate into a single multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
break; break;
} }
} }
......
...@@ -244,8 +244,6 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin ...@@ -244,8 +244,6 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
opMap["square"] = Operation::SQUARE; opMap["square"] = Operation::SQUARE;
opMap["cube"] = Operation::CUBE; opMap["cube"] = Operation::CUBE;
opMap["recip"] = Operation::RECIPROCAL; opMap["recip"] = Operation::RECIPROCAL;
opMap["increment"] = Operation::INCREMENT;
opMap["decrement"] = Operation::DECREMENT;
} }
string trimmed = name.substr(0, name.size()-1); string trimmed = name.substr(0, name.size()-1);
...@@ -291,10 +289,6 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin ...@@ -291,10 +289,6 @@ Operation* Parser::getFunctionOperation(const std::string& name, const map<strin
return new Operation::Cube(); return new Operation::Cube();
case Operation::RECIPROCAL: case Operation::RECIPROCAL:
return new Operation::Reciprocal(); return new Operation::Reciprocal();
case Operation::INCREMENT:
return new Operation::Increment();
case Operation::DECREMENT:
return new Operation::Decrement();
default: default:
throw Exception("Parse error: unknown function"); throw Exception("Parse error: unknown function");
} }
......
...@@ -245,7 +245,7 @@ enum CudaNonbondedMethod ...@@ -245,7 +245,7 @@ enum CudaNonbondedMethod
enum ExpressionOp { enum ExpressionOp {
CONSTANT = 0, VARIABLE0, VARIABLE1, VARIABLE2, VARIABLE3, VARIABLE4, VARIABLE5, VARIABLE6, VARIABLE7, VARIABLE8, GLOBAL, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, CONSTANT = 0, VARIABLE0, VARIABLE1, VARIABLE2, VARIABLE3, VARIABLE4, VARIABLE5, VARIABLE6, VARIABLE7, VARIABLE8, GLOBAL, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE,
POWER, NEGATE, SQRT, EXP, LOG, SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SQUARE, CUBE, RECIPROCAL, INCREMENT, DECREMENT POWER, NEGATE, SQRT, EXP, LOG, SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SQUARE, CUBE, RECIPROCAL, ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT
}; };
template<int SIZE> template<int SIZE>
......
...@@ -146,7 +146,7 @@ static Expression<SIZE> createExpression(const string& expression, const Lepton: ...@@ -146,7 +146,7 @@ static Expression<SIZE> createExpression(const string& expression, const Lepton:
switch (op.getId()) { switch (op.getId()) {
case Operation::CONSTANT: case Operation::CONSTANT:
exp.op[i] = CONSTANT; exp.op[i] = CONSTANT;
exp.arg[i] = op.evaluate(NULL, map<string, double>()); exp.arg[i] = dynamic_cast<const Operation::Constant*>(&op)->getValue();
break; break;
case Operation::VARIABLE: case Operation::VARIABLE:
if (variables.size() > 0 && op.getName() == variables[0]) if (variables.size() > 0 && op.getName() == variables[0])
...@@ -239,11 +239,17 @@ static Expression<SIZE> createExpression(const string& expression, const Lepton: ...@@ -239,11 +239,17 @@ static Expression<SIZE> createExpression(const string& expression, const Lepton:
case Operation::RECIPROCAL: case Operation::RECIPROCAL:
exp.op[i] = RECIPROCAL; exp.op[i] = RECIPROCAL;
break; break;
case Operation::INCREMENT: case Operation::ADD_CONSTANT:
exp.op[i] = INCREMENT; exp.op[i] = ADD_CONSTANT;
exp.arg[i] = dynamic_cast<const Operation::AddConstant*>(&op)->getValue();
break; break;
case Operation::DECREMENT: case Operation::MULTIPLY_CONSTANT:
exp.op[i] = DECREMENT; exp.op[i] = MULTIPLY_CONSTANT;
exp.arg[i] = dynamic_cast<const Operation::MultiplyConstant*>(&op)->getValue();
break;
case Operation::POWER_CONSTANT:
exp.op[i] = POWER_CONSTANT;
exp.arg[i] = dynamic_cast<const Operation::PowerConstant*>(&op)->getValue();
break; break;
} }
} }
......
...@@ -105,132 +105,276 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float* ...@@ -105,132 +105,276 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float*
int stackPointer = -1; int stackPointer = -1;
for (int i = 0; i < expression->length; i++) for (int i = 0; i < expression->length; i++)
{ {
switch (expression->op[i]) int op = expression->op[i];
{ if (op < SQRT) {
case CONSTANT: if (op < VARIABLE8) {
if (op < VARIABLE4) {
if (op == CONSTANT) {
STACK(++stackPointer) = expression->arg[i]; STACK(++stackPointer) = expression->arg[i];
break; }
case VARIABLE0: else if (op == VARIABLE0) {
STACK(++stackPointer) = var0; STACK(++stackPointer) = var0;
break; }
case VARIABLE1: else if (op == VARIABLE1) {
STACK(++stackPointer) = vars1.x; STACK(++stackPointer) = vars1.x;
break; }
case VARIABLE2: else if (op == VARIABLE2) {
STACK(++stackPointer) = vars1.y; STACK(++stackPointer) = vars1.y;
break; }
case VARIABLE3: else if (op == VARIABLE3) {
STACK(++stackPointer) = vars1.z; STACK(++stackPointer) = vars1.z;
break; }
case VARIABLE4: }
else {
if (op == VARIABLE4) {
STACK(++stackPointer) = vars1.w; STACK(++stackPointer) = vars1.w;
break; }
case VARIABLE5: else if (op == VARIABLE5) {
STACK(++stackPointer) = vars2.x; STACK(++stackPointer) = vars2.x;
break; }
case VARIABLE6: else if (op == VARIABLE6) {
STACK(++stackPointer) = vars2.y; STACK(++stackPointer) = vars2.y;
break; }
case VARIABLE7: else if (op == VARIABLE7) {
STACK(++stackPointer) = vars2.z; STACK(++stackPointer) = vars2.z;
break; }
case VARIABLE8: }
}
else {
if (op < MULTIPLY) {
if (op == VARIABLE8) {
STACK(++stackPointer) = vars2.w; STACK(++stackPointer) = vars2.w;
break; }
case GLOBAL: else if (op == GLOBAL) {
STACK(++stackPointer) = globalParams[(int) expression->arg[i]]; STACK(++stackPointer) = globalParams[(int) expression->arg[i]];
break; }
case ADD: else if (op == ADD) {
{
float temp = STACK(stackPointer); float temp = STACK(stackPointer);
STACK(--stackPointer) += temp; STACK(--stackPointer) += temp;
break;
} }
case SUBTRACT: else if (op == SUBTRACT) {
{
float temp = STACK(stackPointer); float temp = STACK(stackPointer);
STACK(stackPointer) = temp-STACK(--stackPointer); STACK(stackPointer) = temp-STACK(--stackPointer);
break;
} }
case MULTIPLY: }
{ else {
if (op == MULTIPLY) {
float temp = STACK(stackPointer); float temp = STACK(stackPointer);
STACK(--stackPointer) *= temp; STACK(--stackPointer) *= temp;
break;
} }
case DIVIDE: else if (op == DIVIDE) {
{
float temp = STACK(stackPointer); float temp = STACK(stackPointer);
STACK(stackPointer) = temp/STACK(--stackPointer); STACK(stackPointer) = temp/STACK(--stackPointer);
break;
} }
case POWER: else if (op == POWER) {
{
float temp = STACK(stackPointer); float temp = STACK(stackPointer);
STACK(stackPointer) = pow(temp, STACK(--stackPointer)); STACK(stackPointer) = pow(temp, STACK(--stackPointer));
break;
} }
case NEGATE: else if (op == NEGATE) {
STACK(stackPointer) *= -1.0f; STACK(stackPointer) *= -1.0f;
break; }
case SQRT: }
}
}
else {
if (op < ASIN) {
if (op < SEC) {
if (op == SQRT) {
STACK(stackPointer) = sqrt(STACK(stackPointer)); STACK(stackPointer) = sqrt(STACK(stackPointer));
break; }
case EXP: else if (op == EXP) {
STACK(stackPointer) = exp(STACK(stackPointer)); STACK(stackPointer) = exp(STACK(stackPointer));
break; }
case LOG: else if (op == LOG) {
STACK(stackPointer) = log(STACK(stackPointer)); STACK(stackPointer) = log(STACK(stackPointer));
break; }
case SIN: else if (op == SIN) {
STACK(stackPointer) = sin(STACK(stackPointer)); STACK(stackPointer) = sin(STACK(stackPointer));
break; }
case COS: else if (op == COS) {
STACK(stackPointer) = cos(STACK(stackPointer)); STACK(stackPointer) = cos(STACK(stackPointer));
break; }
case SEC: }
else {
if (op == SEC) {
STACK(stackPointer) = 1.0f/cos(STACK(stackPointer)); STACK(stackPointer) = 1.0f/cos(STACK(stackPointer));
break; }
case CSC: else if (op == CSC) {
STACK(stackPointer) = 1.0f/sin(STACK(stackPointer)); STACK(stackPointer) = 1.0f/sin(STACK(stackPointer));
break; }
case TAN: else if (op == TAN) {
STACK(stackPointer) = tan(STACK(stackPointer)); STACK(stackPointer) = tan(STACK(stackPointer));
break; }
case COT: else if (op == COT) {
STACK(stackPointer) = 1.0f/tan(STACK(stackPointer)); STACK(stackPointer) = 1.0f/tan(STACK(stackPointer));
break; }
case ASIN: }
}
else {
if (op < RECIPROCAL) {
if (op == ASIN) {
STACK(stackPointer) = asin(STACK(stackPointer)); STACK(stackPointer) = asin(STACK(stackPointer));
break; }
case ACOS: else if (op == ACOS) {
STACK(stackPointer) = acos(STACK(stackPointer)); STACK(stackPointer) = acos(STACK(stackPointer));
break; }
case ATAN: else if (op == ATAN) {
STACK(stackPointer) = atan(STACK(stackPointer)); STACK(stackPointer) = atan(STACK(stackPointer));
break; }
case SQUARE: else if (op == SQUARE) {
{
float temp = STACK(stackPointer); float temp = STACK(stackPointer);
STACK(stackPointer) *= temp; STACK(stackPointer) *= temp;
break;
} }
case CUBE: else if (op == CUBE) {
{
float temp = STACK(stackPointer); float temp = STACK(stackPointer);
STACK(stackPointer) *= temp*temp; STACK(stackPointer) *= temp*temp;
break;
} }
case RECIPROCAL: }
else {
if (op == RECIPROCAL) {
STACK(stackPointer) = 1.0f/STACK(stackPointer); STACK(stackPointer) = 1.0f/STACK(stackPointer);
break;
case INCREMENT:
STACK(stackPointer) += 1.0f;
break;
case DECREMENT:
STACK(stackPointer) -= 1.0f;
break;
} }
else if (op == ADD_CONSTANT) {
STACK(stackPointer) += expression->arg[i];
}
else if (op == MULTIPLY_CONSTANT) {
STACK(stackPointer) *= expression->arg[i];
}
else if (op == POWER_CONSTANT) {
STACK(stackPointer) = pow(STACK(stackPointer), expression->arg[i]);
}
}
}
}
// switch (expression->op[i])
// {
// case CONSTANT:
// STACK(++stackPointer) = expression->arg[i];
// break;
// case VARIABLE0:
// STACK(++stackPointer) = var0;
// break;
// case VARIABLE1:
// STACK(++stackPointer) = vars1.x;
// break;
// case VARIABLE2:
// STACK(++stackPointer) = vars1.y;
// break;
// case VARIABLE3:
// STACK(++stackPointer) = vars1.z;
// break;
// case VARIABLE4:
// STACK(++stackPointer) = vars1.w;
// break;
// case VARIABLE5:
// STACK(++stackPointer) = vars2.x;
// break;
// case VARIABLE6:
// STACK(++stackPointer) = vars2.y;
// break;
// case VARIABLE7:
// STACK(++stackPointer) = vars2.z;
// break;
// case VARIABLE8:
// STACK(++stackPointer) = vars2.w;
// break;
// case GLOBAL:
// STACK(++stackPointer) = globalParams[(int) expression->arg[i]];
// break;
// case ADD:
// {
// float temp = STACK(stackPointer);
// STACK(--stackPointer) += temp;
// break;
// }
// case SUBTRACT:
// {
// float temp = STACK(stackPointer);
// STACK(stackPointer) = temp-STACK(--stackPointer);
// break;
// }
// case MULTIPLY:
// {
// float temp = STACK(stackPointer);
// STACK(--stackPointer) *= temp;
// break;
// }
// case DIVIDE:
// {
// float temp = STACK(stackPointer);
// STACK(stackPointer) = temp/STACK(--stackPointer);
// break;
// }
// case POWER:
// {
// float temp = STACK(stackPointer);
// STACK(stackPointer) = pow(temp, STACK(--stackPointer));
// break;
// }
// case NEGATE:
// STACK(stackPointer) *= -1.0f;
// break;
// case SQRT:
// STACK(stackPointer) = sqrt(STACK(stackPointer));
// break;
// case EXP:
// STACK(stackPointer) = exp(STACK(stackPointer));
// break;
// case LOG:
// STACK(stackPointer) = log(STACK(stackPointer));
// break;
// case SIN:
// STACK(stackPointer) = sin(STACK(stackPointer));
// break;
// case COS:
// STACK(stackPointer) = cos(STACK(stackPointer));
// break;
// case SEC:
// STACK(stackPointer) = 1.0f/cos(STACK(stackPointer));
// break;
// case CSC:
// STACK(stackPointer) = 1.0f/sin(STACK(stackPointer));
// break;
// case TAN:
// STACK(stackPointer) = tan(STACK(stackPointer));
// break;
// case COT:
// STACK(stackPointer) = 1.0f/tan(STACK(stackPointer));
// break;
// case ASIN:
// STACK(stackPointer) = asin(STACK(stackPointer));
// break;
// case ACOS:
// STACK(stackPointer) = acos(STACK(stackPointer));
// break;
// case ATAN:
// STACK(stackPointer) = atan(STACK(stackPointer));
// break;
// case SQUARE:
// {
// float temp = STACK(stackPointer);
// STACK(stackPointer) *= temp;
// break;
// }
// case CUBE:
// {
// float temp = STACK(stackPointer);
// STACK(stackPointer) *= temp*temp;
// break;
// }
// case RECIPROCAL:
// STACK(stackPointer) = 1.0f/STACK(stackPointer);
// break;
// case ADD_CONSTANT:
// STACK(stackPointer) += expression->arg[i];
// break;
// case MULTIPLY_CONSTANT:
// STACK(stackPointer) *= expression->arg[i];
// break;
// case POWER_CONSTANT:
// STACK(stackPointer) = pow(STACK(stackPointer), expression->arg[i]);
// break;
// }
} }
return STACK(stackPointer); return STACK(stackPointer);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment