Commit 4a25dc79 authored by peastman's avatar peastman
Browse files

Added more optimized operations to JIT compilation

parent a6fe70d2
...@@ -80,6 +80,7 @@ private: ...@@ -80,6 +80,7 @@ private:
CompiledExpression(const ParsedExpression& expression); CompiledExpression(const ParsedExpression& expression);
void compileExpression(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps); void compileExpression(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps);
void generateJitCode(); void generateJitCode();
void generateSingleArgCall(asmjit::X86Compiler& c, asmjit::X86XmmVar& dest, asmjit::X86XmmVar& arg, double (*function)(double));
int findTempIndex(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps); int findTempIndex(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps);
std::vector<std::vector<int> > arguments; std::vector<std::vector<int> > arguments;
std::vector<int> target; std::vector<int> target;
......
...@@ -195,6 +195,10 @@ void CompiledExpression::generateJitCode() { ...@@ -195,6 +195,10 @@ void CompiledExpression::generateJitCode() {
value = dynamic_cast<Operation::MultiplyConstant&>(op).getValue(); value = dynamic_cast<Operation::MultiplyConstant&>(op).getValue();
else if (op.getId() == Operation::RECIPROCAL) else if (op.getId() == Operation::RECIPROCAL)
value = 1.0; value = 1.0;
else if (op.getId() == Operation::STEP)
value = 1.0;
else if (op.getId() == Operation::DELTA)
value = 1.0;
else else
continue; continue;
...@@ -232,55 +236,106 @@ void CompiledExpression::generateJitCode() { ...@@ -232,55 +236,106 @@ void CompiledExpression::generateJitCode() {
for (int i = 1; i < op.getNumArguments(); i++) for (int i = 1; i < op.getNumArguments(); i++)
args.push_back(args[0]+i); args.push_back(args[0]+i);
} }
// Generate instructions to execute this operation.
switch (op.getId()) { switch (op.getId()) {
case Operation::CONSTANT: case Operation::CONSTANT:
c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break; break;
case Operation::ADD: case Operation::ADD:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.addsd(workspaceVar[target[step]], workspaceVar[args[1]]); c.addsd(workspaceVar[target[step]], workspaceVar[args[1]]);
break; break;
case Operation::SUBTRACT: case Operation::SUBTRACT:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.subsd(workspaceVar[target[step]], workspaceVar[args[1]]); c.subsd(workspaceVar[target[step]], workspaceVar[args[1]]);
break; break;
case Operation::MULTIPLY: case Operation::MULTIPLY:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[1]]); c.mulsd(workspaceVar[target[step]], workspaceVar[args[1]]);
break; break;
case Operation::DIVIDE: case Operation::DIVIDE:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.divsd(workspaceVar[target[step]], workspaceVar[args[1]]); c.divsd(workspaceVar[target[step]], workspaceVar[args[1]]);
break; break;
case Operation::NEGATE: case Operation::NEGATE:
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
c.subsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.subsd(workspaceVar[target[step]], workspaceVar[args[0]]);
break; break;
case Operation::SQRT: case Operation::SQRT:
c.sqrtsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.sqrtsd(workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::EXP:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp);
break;
case Operation::LOG:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], log);
break;
case Operation::SIN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sin);
break;
case Operation::COS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cos);
break;
case Operation::TAN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tan);
break;
case Operation::ASIN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asin);
break;
case Operation::ACOS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acos);
break;
case Operation::ATAN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atan);
break;
case Operation::SINH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinh);
break;
case Operation::COSH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosh);
break;
case Operation::TANH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh);
break;
case Operation::STEP:
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18
c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::DELTA:
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16
c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break; break;
case Operation::SQUARE: case Operation::SQUARE:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
break; break;
case Operation::CUBE: case Operation::CUBE:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
break; break;
case Operation::RECIPROCAL: case Operation::RECIPROCAL:
c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
c.divsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.divsd(workspaceVar[target[step]], workspaceVar[args[0]]);
break; break;
case Operation::ADD_CONSTANT: case Operation::ADD_CONSTANT:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.addsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); c.addsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break; break;
case Operation::MULTIPLY_CONSTANT: case Operation::MULTIPLY_CONSTANT:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::ABS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs);
break; break;
default: default:
// Just invoke evaluateOperation().
for (int i = 0; i < (int) args.size(); i++) for (int i = 0; i < (int) args.size(); i++)
c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]); c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]);
X86GpVar fn(c, kVarTypeIntPtr); X86GpVar fn(c, kVarTypeIntPtr);
...@@ -295,3 +350,11 @@ void CompiledExpression::generateJitCode() { ...@@ -295,3 +350,11 @@ void CompiledExpression::generateJitCode() {
c.endFunc(); c.endFunc();
jitCode = c.make(); jitCode = c.make();
} }
void CompiledExpression::generateSingleArgCall(X86Compiler& c, X86XmmVar& dest, X86XmmVar& arg, double (*function)(double)) {
X86GpVar fn(c, kVarTypeIntPtr);
c.mov(fn, imm_ptr((void*) function));
X86CallNode* call = c.call(fn, kFuncConvHost, FuncBuilder1<double, double>());
call->setArg(0, arg);
call->setRet(0, dest);
}
...@@ -127,6 +127,18 @@ void verifyInvalidExpression(const string& expression) { ...@@ -127,6 +127,18 @@ void verifyInvalidExpression(const string& expression) {
throw exception(); throw exception();
} }
/**
* Verify that two numbers have the same value.
*/
void assertNumbersEqual(double val1, double val2) {
const double inf = numeric_limits<double>::infinity();
if (val1 == val1 || val2 == val2) // If both are NaN, that's fine.
if (val1 != inf || val2 != inf) // Both infinity is also fine.
if (val1 != -inf || val2 != -inf) // Same for -infinity.
ASSERT_EQUAL_TOL(val1, val2, 1e-10);
}
/** /**
* Verify that two expressions give the same value. * Verify that two expressions give the same value.
*/ */
...@@ -137,11 +149,22 @@ void verifySameValue(const ParsedExpression& exp1, const ParsedExpression& exp2, ...@@ -137,11 +149,22 @@ void verifySameValue(const ParsedExpression& exp1, const ParsedExpression& exp2,
variables["y"] = y; variables["y"] = y;
double val1 = exp1.evaluate(variables); double val1 = exp1.evaluate(variables);
double val2 = exp2.evaluate(variables); double val2 = exp2.evaluate(variables);
const double inf = numeric_limits<double>::infinity(); assertNumbersEqual(val1, val2);
if (val1 == val1 || val2 == val2) // If both are NaN, that's fine.
if (val1 != inf || val2 != inf) // Both infinity is also fine. // Now create CompiledExpressions from them and see if those also match.
if (val1 != -inf || val2 != -inf) // Same for -infinity.
ASSERT_EQUAL_TOL(val1, val2, 1e-10); CompiledExpression compiled1 = exp1.createCompiledExpression();
CompiledExpression compiled2 = exp2.createCompiledExpression();
if (compiled1.getVariables().find("x") != compiled1.getVariables().end())
compiled1.getVariableReference("x") = x;
if (compiled1.getVariables().find("y") != compiled1.getVariables().end())
compiled1.getVariableReference("y") = y;
if (compiled2.getVariables().find("x") != compiled2.getVariables().end())
compiled2.getVariableReference("x") = x;
if (compiled2.getVariables().find("y") != compiled2.getVariables().end())
compiled2.getVariableReference("y") = y;
assertNumbersEqual(val1, compiled1.evaluate());
assertNumbersEqual(val2, compiled2.evaluate());
} }
/** /**
...@@ -171,14 +194,14 @@ void testCustomFunction(const string& expression, const string& equivalent) { ...@@ -171,14 +194,14 @@ void testCustomFunction(const string& expression, const string& equivalent) {
verifySameValue(exp1, exp2, 2.0, 3.0); verifySameValue(exp1, exp2, 2.0, 3.0);
verifySameValue(exp1, exp2, -2.0, 3.0); verifySameValue(exp1, exp2, -2.0, 3.0);
verifySameValue(exp1, exp2, 2.0, -3.0); verifySameValue(exp1, exp2, 2.0, -3.0);
ParsedExpression deriv1 = exp1.differentiate("x"); ParsedExpression deriv1 = exp1.differentiate("x").optimize();
ParsedExpression deriv2 = exp2.differentiate("x"); ParsedExpression deriv2 = exp2.differentiate("x").optimize();
verifySameValue(deriv1, deriv2, 1.0, 2.0); verifySameValue(deriv1, deriv2, 1.0, 2.0);
verifySameValue(deriv1, deriv2, 2.0, 3.0); verifySameValue(deriv1, deriv2, 2.0, 3.0);
verifySameValue(deriv1, deriv2, -2.0, 3.0); verifySameValue(deriv1, deriv2, -2.0, 3.0);
verifySameValue(deriv1, deriv2, 2.0, -3.0); verifySameValue(deriv1, deriv2, 2.0, -3.0);
ParsedExpression deriv3 = deriv1.differentiate("y"); ParsedExpression deriv3 = deriv1.differentiate("y").optimize();
ParsedExpression deriv4 = deriv2.differentiate("y"); ParsedExpression deriv4 = deriv2.differentiate("y").optimize();
verifySameValue(deriv3, deriv4, 1.0, 2.0); verifySameValue(deriv3, deriv4, 1.0, 2.0);
verifySameValue(deriv3, deriv4, 2.0, 3.0); verifySameValue(deriv3, deriv4, 2.0, 3.0);
verifySameValue(deriv3, deriv4, -2.0, 3.0); verifySameValue(deriv3, deriv4, -2.0, 3.0);
...@@ -223,6 +246,7 @@ int main() { ...@@ -223,6 +246,7 @@ int main() {
verifyEvaluation("max(x, -1)", 2.0, 3.0, 2.0); verifyEvaluation("max(x, -1)", 2.0, 3.0, 2.0);
verifyEvaluation("abs(x-y)", 2.0, 3.0, 1.0); verifyEvaluation("abs(x-y)", 2.0, 3.0, 1.0);
verifyEvaluation("delta(x)+3*delta(y-1.5)", 2.0, 1.5, 3.0); verifyEvaluation("delta(x)+3*delta(y-1.5)", 2.0, 1.5, 3.0);
verifyEvaluation("step(x-3)+y*step(x)", 2.0, 3.0, 3.0);
verifyInvalidExpression("1..2"); verifyInvalidExpression("1..2");
verifyInvalidExpression("1*(2+3"); verifyInvalidExpression("1*(2+3");
verifyInvalidExpression("5++4"); verifyInvalidExpression("5++4");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment