Commit 4a25dc79 authored by peastman
Browse files

Added more optimized operations to JIT compilation

parent a6fe70d2
......@@ -80,6 +80,7 @@ private:
CompiledExpression(const ParsedExpression& expression);
void compileExpression(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps);
void generateJitCode();
void generateSingleArgCall(asmjit::X86Compiler& c, asmjit::X86XmmVar& dest, asmjit::X86XmmVar& arg, double (*function)(double));
int findTempIndex(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps);
std::vector<std::vector<int> > arguments;
std::vector<int> target;
......
......@@ -195,6 +195,10 @@ void CompiledExpression::generateJitCode() {
value = dynamic_cast<Operation::MultiplyConstant&>(op).getValue();
else if (op.getId() == Operation::RECIPROCAL)
value = 1.0;
else if (op.getId() == Operation::STEP)
value = 1.0;
else if (op.getId() == Operation::DELTA)
value = 1.0;
else
continue;
......@@ -232,6 +236,9 @@ void CompiledExpression::generateJitCode() {
for (int i = 1; i < op.getNumArguments(); i++)
args.push_back(args[0]+i);
}
// Generate instructions to execute this operation.
switch (op.getId()) {
case Operation::CONSTANT:
c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
......@@ -259,6 +266,49 @@ void CompiledExpression::generateJitCode() {
case Operation::SQRT:
c.sqrtsd(workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::EXP:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp);
break;
case Operation::LOG:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], log);
break;
case Operation::SIN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sin);
break;
case Operation::COS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cos);
break;
case Operation::TAN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tan);
break;
case Operation::ASIN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asin);
break;
case Operation::ACOS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acos);
break;
case Operation::ATAN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atan);
break;
case Operation::SINH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinh);
break;
case Operation::COSH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosh);
break;
case Operation::TANH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh);
break;
case Operation::STEP:
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18
c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::DELTA:
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16
c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::SQUARE:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
......@@ -280,7 +330,12 @@ void CompiledExpression::generateJitCode() {
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::ABS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs);
break;
default:
// Just invoke evaluateOperation().
for (int i = 0; i < (int) args.size(); i++)
c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]);
X86GpVar fn(c, kVarTypeIntPtr);
......@@ -295,3 +350,11 @@ void CompiledExpression::generateJitCode() {
c.endFunc();
jitCode = c.make();
}
/**
 * Emit the JIT instructions needed to invoke a function taking one double and returning a double.
 *
 * The address of the function is moved into a freshly created general purpose variable,
 * a call through that variable is emitted with a double(double) signature, and the
 * supplied XMM variables are bound as the call's argument and return value.
 *
 * @param c         the compiler the instructions are emitted into
 * @param dest      the XMM variable bound to the call's return value
 * @param arg       the XMM variable bound to the call's single argument
 * @param function  the function the generated code should call
 */
void CompiledExpression::generateSingleArgCall(X86Compiler& c, X86XmmVar& dest, X86XmmVar& arg, double (*function)(double)) {
    // Load the function's address so the generated code can call through it.
    X86GpVar fnAddress(c, kVarTypeIntPtr);
    c.mov(fnAddress, imm_ptr(reinterpret_cast<void*>(function)));

    // Emit the call, then wire up its input and output.
    X86CallNode* callNode = c.call(fnAddress, kFuncConvHost, FuncBuilder1<double, double>());
    callNode->setArg(0, arg);
    callNode->setRet(0, dest);
}
......@@ -127,6 +127,18 @@ void verifyInvalidExpression(const string& expression) {
throw exception();
}
/**
 * Check that two numbers agree with each other.
 *
 * A pair of NaNs, a pair of +infinity values, or a pair of -infinity values is accepted
 * without further checks. Every other pair is compared with ASSERT_EQUAL_TOL using a
 * tolerance of 1e-10.
 */
void assertNumbersEqual(double val1, double val2) {
    const double inf = numeric_limits<double>::infinity();
    // A value is NaN exactly when it does not compare equal to itself.
    const bool bothNaN = (val1 != val1 && val2 != val2);
    const bool bothPositiveInf = (val1 == inf && val2 == inf);
    const bool bothNegativeInf = (val1 == -inf && val2 == -inf);
    if (bothNaN || bothPositiveInf || bothNegativeInf)
        return;
    ASSERT_EQUAL_TOL(val1, val2, 1e-10);
}
/**
* Verify that two expressions give the same value.
*/
......@@ -137,11 +149,22 @@ void verifySameValue(const ParsedExpression& exp1, const ParsedExpression& exp2,
variables["y"] = y;
double val1 = exp1.evaluate(variables);
double val2 = exp2.evaluate(variables);
const double inf = numeric_limits<double>::infinity();
if (val1 == val1 || val2 == val2) // If both are NaN, that's fine.
if (val1 != inf || val2 != inf) // Both infinity is also fine.
if (val1 != -inf || val2 != -inf) // Same for -infinity.
ASSERT_EQUAL_TOL(val1, val2, 1e-10);
assertNumbersEqual(val1, val2);
// Now create CompiledExpressions from them and see if those also match.
CompiledExpression compiled1 = exp1.createCompiledExpression();
CompiledExpression compiled2 = exp2.createCompiledExpression();
if (compiled1.getVariables().find("x") != compiled1.getVariables().end())
compiled1.getVariableReference("x") = x;
if (compiled1.getVariables().find("y") != compiled1.getVariables().end())
compiled1.getVariableReference("y") = y;
if (compiled2.getVariables().find("x") != compiled2.getVariables().end())
compiled2.getVariableReference("x") = x;
if (compiled2.getVariables().find("y") != compiled2.getVariables().end())
compiled2.getVariableReference("y") = y;
assertNumbersEqual(val1, compiled1.evaluate());
assertNumbersEqual(val2, compiled2.evaluate());
}
/**
......@@ -171,14 +194,14 @@ void testCustomFunction(const string& expression, const string& equivalent) {
verifySameValue(exp1, exp2, 2.0, 3.0);
verifySameValue(exp1, exp2, -2.0, 3.0);
verifySameValue(exp1, exp2, 2.0, -3.0);
ParsedExpression deriv1 = exp1.differentiate("x");
ParsedExpression deriv2 = exp2.differentiate("x");
ParsedExpression deriv1 = exp1.differentiate("x").optimize();
ParsedExpression deriv2 = exp2.differentiate("x").optimize();
verifySameValue(deriv1, deriv2, 1.0, 2.0);
verifySameValue(deriv1, deriv2, 2.0, 3.0);
verifySameValue(deriv1, deriv2, -2.0, 3.0);
verifySameValue(deriv1, deriv2, 2.0, -3.0);
ParsedExpression deriv3 = deriv1.differentiate("y");
ParsedExpression deriv4 = deriv2.differentiate("y");
ParsedExpression deriv3 = deriv1.differentiate("y").optimize();
ParsedExpression deriv4 = deriv2.differentiate("y").optimize();
verifySameValue(deriv3, deriv4, 1.0, 2.0);
verifySameValue(deriv3, deriv4, 2.0, 3.0);
verifySameValue(deriv3, deriv4, -2.0, 3.0);
......@@ -223,6 +246,7 @@ int main() {
verifyEvaluation("max(x, -1)", 2.0, 3.0, 2.0);
verifyEvaluation("abs(x-y)", 2.0, 3.0, 1.0);
verifyEvaluation("delta(x)+3*delta(y-1.5)", 2.0, 1.5, 3.0);
verifyEvaluation("step(x-3)+y*step(x)", 2.0, 3.0, 3.0);
verifyInvalidExpression("1..2");
verifyInvalidExpression("1*(2+3");
verifyInvalidExpression("5++4");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment