Commit a83607d6 authored by peastman's avatar peastman
Browse files

Preliminary implementation of JIT compilation for CompiledExpressions

parent 4c19a401
// [AsmJit]
// Complete x86/x64 JIT and Remote Assembler for C++.
//
// [License]
// Zlib - See LICENSE.md file in the package.
// [Export]
#define ASMJIT_EXPORTS
#define ASMJIT_EXPORTS_X86OPERAND_REGS
// [Guard]
#include "../build.h"
#if defined(ASMJIT_BUILD_X86) || defined(ASMJIT_BUILD_X64)
// [Dependencies - AsmJit]
#include "../x86/x86operand.h"
// [Api-Begin]
#include "../apibegin.h"
namespace asmjit {
// Prevent static initialization.
//
// Remap all classes to POD structs so they can be statically initialized
// without calling a constructor. Compiler will store these in data section.
struct X86GpReg { Operand::VRegOp data; };
struct X86FpReg { Operand::VRegOp data; };
struct X86MmReg { Operand::VRegOp data; };
struct X86XmmReg { Operand::VRegOp data; };
struct X86YmmReg { Operand::VRegOp data; };
struct X86SegReg { Operand::VRegOp data; };
namespace x86 {
// ============================================================================
// [asmjit::x86::Registers]
// ============================================================================
#define REG(_Class_, _Name_, _Type_, _Index_, _Size_) \
const _Class_ _Name_ = {{ \
kOperandTypeReg, _Size_, { ((_Type_) << 8) + _Index_ }, kInvalidValue, {{ kInvalidVar, 0 }} \
}}
REG(X86GpReg, noGpReg, kInvalidReg, kInvalidReg, 0);
REG(X86GpReg, al, kX86RegTypeGpbLo, kX86RegIndexAx, 1);
REG(X86GpReg, cl, kX86RegTypeGpbLo, kX86RegIndexCx, 1);
REG(X86GpReg, dl, kX86RegTypeGpbLo, kX86RegIndexDx, 1);
REG(X86GpReg, bl, kX86RegTypeGpbLo, kX86RegIndexBx, 1);
REG(X86GpReg, spl, kX86RegTypeGpbLo, kX86RegIndexSp, 1);
REG(X86GpReg, bpl, kX86RegTypeGpbLo, kX86RegIndexBp, 1);
REG(X86GpReg, sil, kX86RegTypeGpbLo, kX86RegIndexSi, 1);
REG(X86GpReg, dil, kX86RegTypeGpbLo, kX86RegIndexDi, 1);
REG(X86GpReg, r8b, kX86RegTypeGpbLo, 8, 1);
REG(X86GpReg, r9b, kX86RegTypeGpbLo, 9, 1);
REG(X86GpReg, r10b, kX86RegTypeGpbLo, 10, 1);
REG(X86GpReg, r11b, kX86RegTypeGpbLo, 11, 1);
REG(X86GpReg, r12b, kX86RegTypeGpbLo, 12, 1);
REG(X86GpReg, r13b, kX86RegTypeGpbLo, 13, 1);
REG(X86GpReg, r14b, kX86RegTypeGpbLo, 14, 1);
REG(X86GpReg, r15b, kX86RegTypeGpbLo, 15, 1);
REG(X86GpReg, ah, kX86RegTypeGpbHi, kX86RegIndexAx, 1);
REG(X86GpReg, ch, kX86RegTypeGpbHi, kX86RegIndexCx, 1);
REG(X86GpReg, dh, kX86RegTypeGpbHi, kX86RegIndexDx, 1);
REG(X86GpReg, bh, kX86RegTypeGpbHi, kX86RegIndexBx, 1);
REG(X86GpReg, ax, kX86RegTypeGpw, kX86RegIndexAx, 2);
REG(X86GpReg, cx, kX86RegTypeGpw, kX86RegIndexCx, 2);
REG(X86GpReg, dx, kX86RegTypeGpw, kX86RegIndexDx, 2);
REG(X86GpReg, bx, kX86RegTypeGpw, kX86RegIndexBx, 2);
REG(X86GpReg, sp, kX86RegTypeGpw, kX86RegIndexSp, 2);
REG(X86GpReg, bp, kX86RegTypeGpw, kX86RegIndexBp, 2);
REG(X86GpReg, si, kX86RegTypeGpw, kX86RegIndexSi, 2);
REG(X86GpReg, di, kX86RegTypeGpw, kX86RegIndexDi, 2);
REG(X86GpReg, r8w, kX86RegTypeGpw, 8, 2);
REG(X86GpReg, r9w, kX86RegTypeGpw, 9, 2);
REG(X86GpReg, r10w, kX86RegTypeGpw, 10, 2);
REG(X86GpReg, r11w, kX86RegTypeGpw, 11, 2);
REG(X86GpReg, r12w, kX86RegTypeGpw, 12, 2);
REG(X86GpReg, r13w, kX86RegTypeGpw, 13, 2);
REG(X86GpReg, r14w, kX86RegTypeGpw, 14, 2);
REG(X86GpReg, r15w, kX86RegTypeGpw, 15, 2);
REG(X86GpReg, eax, kX86RegTypeGpd, kX86RegIndexAx, 4);
REG(X86GpReg, ecx, kX86RegTypeGpd, kX86RegIndexCx, 4);
REG(X86GpReg, edx, kX86RegTypeGpd, kX86RegIndexDx, 4);
REG(X86GpReg, ebx, kX86RegTypeGpd, kX86RegIndexBx, 4);
REG(X86GpReg, esp, kX86RegTypeGpd, kX86RegIndexSp, 4);
REG(X86GpReg, ebp, kX86RegTypeGpd, kX86RegIndexBp, 4);
REG(X86GpReg, esi, kX86RegTypeGpd, kX86RegIndexSi, 4);
REG(X86GpReg, edi, kX86RegTypeGpd, kX86RegIndexDi, 4);
REG(X86GpReg, r8d, kX86RegTypeGpd, 8, 4);
REG(X86GpReg, r9d, kX86RegTypeGpd, 9, 4);
REG(X86GpReg, r10d, kX86RegTypeGpd, 10, 4);
REG(X86GpReg, r11d, kX86RegTypeGpd, 11, 4);
REG(X86GpReg, r12d, kX86RegTypeGpd, 12, 4);
REG(X86GpReg, r13d, kX86RegTypeGpd, 13, 4);
REG(X86GpReg, r14d, kX86RegTypeGpd, 14, 4);
REG(X86GpReg, r15d, kX86RegTypeGpd, 15, 4);
REG(X86GpReg, rax, kX86RegTypeGpq, kX86RegIndexAx, 8);
REG(X86GpReg, rcx, kX86RegTypeGpq, kX86RegIndexCx, 8);
REG(X86GpReg, rdx, kX86RegTypeGpq, kX86RegIndexDx, 8);
REG(X86GpReg, rbx, kX86RegTypeGpq, kX86RegIndexBx, 8);
REG(X86GpReg, rsp, kX86RegTypeGpq, kX86RegIndexSp, 8);
REG(X86GpReg, rbp, kX86RegTypeGpq, kX86RegIndexBp, 8);
REG(X86GpReg, rsi, kX86RegTypeGpq, kX86RegIndexSi, 8);
REG(X86GpReg, rdi, kX86RegTypeGpq, kX86RegIndexDi, 8);
REG(X86GpReg, r8, kX86RegTypeGpq, 8, 8);
REG(X86GpReg, r9, kX86RegTypeGpq, 9, 8);
REG(X86GpReg, r10, kX86RegTypeGpq, 10, 8);
REG(X86GpReg, r11, kX86RegTypeGpq, 11, 8);
REG(X86GpReg, r12, kX86RegTypeGpq, 12, 8);
REG(X86GpReg, r13, kX86RegTypeGpq, 13, 8);
REG(X86GpReg, r14, kX86RegTypeGpq, 14, 8);
REG(X86GpReg, r15, kX86RegTypeGpq, 15, 8);
REG(X86FpReg, fp0, kX86RegTypeFp, 0, 10);
REG(X86FpReg, fp1, kX86RegTypeFp, 1, 10);
REG(X86FpReg, fp2, kX86RegTypeFp, 2, 10);
REG(X86FpReg, fp3, kX86RegTypeFp, 3, 10);
REG(X86FpReg, fp4, kX86RegTypeFp, 4, 10);
REG(X86FpReg, fp5, kX86RegTypeFp, 5, 10);
REG(X86FpReg, fp6, kX86RegTypeFp, 6, 10);
REG(X86FpReg, fp7, kX86RegTypeFp, 7, 10);
REG(X86MmReg, mm0, kX86RegTypeMm, 0, 8);
REG(X86MmReg, mm1, kX86RegTypeMm, 1, 8);
REG(X86MmReg, mm2, kX86RegTypeMm, 2, 8);
REG(X86MmReg, mm3, kX86RegTypeMm, 3, 8);
REG(X86MmReg, mm4, kX86RegTypeMm, 4, 8);
REG(X86MmReg, mm5, kX86RegTypeMm, 5, 8);
REG(X86MmReg, mm6, kX86RegTypeMm, 6, 8);
REG(X86MmReg, mm7, kX86RegTypeMm, 7, 8);
REG(X86XmmReg, xmm0, kX86RegTypeXmm, 0, 16);
REG(X86XmmReg, xmm1, kX86RegTypeXmm, 1, 16);
REG(X86XmmReg, xmm2, kX86RegTypeXmm, 2, 16);
REG(X86XmmReg, xmm3, kX86RegTypeXmm, 3, 16);
REG(X86XmmReg, xmm4, kX86RegTypeXmm, 4, 16);
REG(X86XmmReg, xmm5, kX86RegTypeXmm, 5, 16);
REG(X86XmmReg, xmm6, kX86RegTypeXmm, 6, 16);
REG(X86XmmReg, xmm7, kX86RegTypeXmm, 7, 16);
REG(X86XmmReg, xmm8, kX86RegTypeXmm, 8, 16);
REG(X86XmmReg, xmm9, kX86RegTypeXmm, 9, 16);
REG(X86XmmReg, xmm10, kX86RegTypeXmm, 10, 16);
REG(X86XmmReg, xmm11, kX86RegTypeXmm, 11, 16);
REG(X86XmmReg, xmm12, kX86RegTypeXmm, 12, 16);
REG(X86XmmReg, xmm13, kX86RegTypeXmm, 13, 16);
REG(X86XmmReg, xmm14, kX86RegTypeXmm, 14, 16);
REG(X86XmmReg, xmm15, kX86RegTypeXmm, 15, 16);
REG(X86YmmReg, ymm0, kX86RegTypeYmm, 0, 32);
REG(X86YmmReg, ymm1, kX86RegTypeYmm, 1, 32);
REG(X86YmmReg, ymm2, kX86RegTypeYmm, 2, 32);
REG(X86YmmReg, ymm3, kX86RegTypeYmm, 3, 32);
REG(X86YmmReg, ymm4, kX86RegTypeYmm, 4, 32);
REG(X86YmmReg, ymm5, kX86RegTypeYmm, 5, 32);
REG(X86YmmReg, ymm6, kX86RegTypeYmm, 6, 32);
REG(X86YmmReg, ymm7, kX86RegTypeYmm, 7, 32);
REG(X86YmmReg, ymm8, kX86RegTypeYmm, 8, 32);
REG(X86YmmReg, ymm9, kX86RegTypeYmm, 9, 32);
REG(X86YmmReg, ymm10, kX86RegTypeYmm, 10, 32);
REG(X86YmmReg, ymm11, kX86RegTypeYmm, 11, 32);
REG(X86YmmReg, ymm12, kX86RegTypeYmm, 12, 32);
REG(X86YmmReg, ymm13, kX86RegTypeYmm, 13, 32);
REG(X86YmmReg, ymm14, kX86RegTypeYmm, 14, 32);
REG(X86YmmReg, ymm15, kX86RegTypeYmm, 15, 32);
REG(X86SegReg, cs, kX86RegTypeSeg, kX86SegCs, 2);
REG(X86SegReg, ss, kX86RegTypeSeg, kX86SegSs, 2);
REG(X86SegReg, ds, kX86RegTypeSeg, kX86SegDs, 2);
REG(X86SegReg, es, kX86RegTypeSeg, kX86SegEs, 2);
REG(X86SegReg, fs, kX86RegTypeSeg, kX86SegFs, 2);
REG(X86SegReg, gs, kX86RegTypeSeg, kX86SegGs, 2);
#undef REG
} // x86 namespace
} // asmjit namespace
// [Api-End]
#include "../apiend.h"
// [Guard]
#endif // ASMJIT_BUILD_X86 || ASMJIT_BUILD_X64
// [AsmJit]
// Complete x86/x64 JIT and Remote Assembler for C++.
//
// [License]
// Zlib - See LICENSE.md file in the package.
// [Export]
#define ASMJIT_EXPORTS
// [Guard]
#include "../build.h"
#if !defined(ASMJIT_DISABLE_COMPILER) && (defined(ASMJIT_BUILD_X86) || defined(ASMJIT_BUILD_X64))
// [Dependencies - AsmJit]
#include "../base/containers.h"
#include "../x86/x86scheduler_p.h"
// [Api-Begin]
#include "../apibegin.h"
namespace asmjit {
// ============================================================================
// [Internals]
// ============================================================================
//! \internal
struct X86ScheduleData {
//! Registers read by the instruction.
X86RegMask regsIn;
//! Registers written by the instruction.
X86RegMask regsOut;
//! Flags read by the instruction.
uint8_t flagsIn;
//! Flags written by the instruction.
uint8_t flagsOut;
//! How many `uops` or `cycles` the instruction takes.
uint8_t ops;
//! Instruction latency.
uint8_t latency;
//! Which ports the instruction can run at.
uint16_t ports;
//! \internal
uint16_t reserved;
//! All instructions that this instruction depends on.
PodList<InstNode*>::Link* dependsOn;
//! All instructions that use the result of this instruction.
PodList<InstNode*>::Link* usedBy;
};
// ============================================================================
// [asmjit::X86Scheduler - Construction / Destruction]
// ============================================================================
X86Scheduler::X86Scheduler(X86Compiler* compiler, const X86CpuInfo* cpuInfo) :
_compiler(compiler),
_cpuInfo(cpuInfo) {}
X86Scheduler::~X86Scheduler() {}
// ============================================================================
// [asmjit::X86Scheduler - Run]
// ============================================================================
Error X86Scheduler::run(Node* start, Node* stop) {
/*
ASMJIT_TLOG("[Schedule] === Begin ===");
Zone zone(8096 - kZoneOverhead);
Node* node_ = start;
while (node_ != stop) {
Node* next = node_->getNext();
ASMJIT_ASSERT(node_->getType() == kNodeTypeInst);
printf(" %s\n", X86Util::getInstInfo(static_cast<InstNode*>(node_)->getCode()).getInstName());
node_ = next;
}
ASMJIT_TLOG("[Schedule] === End ===");
*/
return kErrorOk;
}
} // asmjit namespace
// [Api-End]
#include "../apiend.h"
// [Guard]
#endif // !ASMJIT_DISABLE_COMPILER && (ASMJIT_BUILD_X86 || ASMJIT_BUILD_X64)
// [AsmJit]
// Complete x86/x64 JIT and Remote Assembler for C++.
//
// [License]
// Zlib - See LICENSE.md file in the package.
// [Guard]
#ifndef _ASMJIT_X86_X86SCHEDULER_P_H
#define _ASMJIT_X86_X86SCHEDULER_P_H
#include "../build.h"
#if !defined(ASMJIT_DISABLE_COMPILER)
// [Dependencies - AsmJit]
#include "../x86/x86compiler.h"
#include "../x86/x86context_p.h"
#include "../x86/x86cpuinfo.h"
#include "../x86/x86inst.h"
// [Api-Begin]
#include "../apibegin.h"
namespace asmjit {
// ============================================================================
// [asmjit::X86Scheduler]
// ============================================================================
//! \internal
//!
//! X86 scheduler.
struct X86Scheduler {
// --------------------------------------------------------------------------
// [Construction / Destruction]
// --------------------------------------------------------------------------
X86Scheduler(X86Compiler* compiler, const X86CpuInfo* cpuInfo);
~X86Scheduler();
// --------------------------------------------------------------------------
// [Run]
// --------------------------------------------------------------------------
Error run(Node* start, Node* stop);
// --------------------------------------------------------------------------
// [Members]
// --------------------------------------------------------------------------
//! Attached compiler.
X86Compiler* _compiler;
//! CPU information used for scheduling.
const X86CpuInfo* _cpuInfo;
};
} // asmjit namespace
// [Api-End]
#include "../apiend.h"
// [Guard]
#endif // !ASMJIT_DISABLE_COMPILER
#endif // _ASMJIT_X86_X86SCHEDULER_P_H
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "ExpressionTreeNode.h" #include "ExpressionTreeNode.h"
#include "windowsIncludes.h" #include "windowsIncludes.h"
#include "asmjit.h"
#include <map> #include <map>
#include <set> #include <set>
#include <string> #include <string>
...@@ -78,6 +79,7 @@ private: ...@@ -78,6 +79,7 @@ private:
friend class ParsedExpression; friend class ParsedExpression;
CompiledExpression(const ParsedExpression& expression); CompiledExpression(const ParsedExpression& expression);
void compileExpression(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps); void compileExpression(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps);
void generateJitCode();
int findTempIndex(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps); int findTempIndex(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps);
std::vector<std::vector<int> > arguments; std::vector<std::vector<int> > arguments;
std::vector<int> target; std::vector<int> target;
...@@ -87,6 +89,8 @@ private: ...@@ -87,6 +89,8 @@ private:
mutable std::vector<double> workspace; mutable std::vector<double> workspace;
mutable std::vector<double> argValues; mutable std::vector<double> argValues;
std::map<std::string, double> dummyVariables; std::map<std::string, double> dummyVariables;
asmjit::JitRuntime runtime;
void* jitCode;
}; };
} // namespace Lepton } // namespace Lepton
......
...@@ -36,14 +36,21 @@ ...@@ -36,14 +36,21 @@
using namespace Lepton; using namespace Lepton;
using namespace std; using namespace std;
using namespace asmjit;
CompiledExpression::CompiledExpression() { CompiledExpression::CompiledExpression() : jitCode(NULL) {
} }
CompiledExpression::CompiledExpression(const ParsedExpression& expression) { CompiledExpression::CompiledExpression(const ParsedExpression& expression) : jitCode(NULL) {
ParsedExpression expr = expression.optimize(); // Just in case it wasn't already optimized. ParsedExpression expr = expression.optimize(); // Just in case it wasn't already optimized.
vector<pair<ExpressionTreeNode, int> > temps; vector<pair<ExpressionTreeNode, int> > temps;
compileExpression(expr.getRootNode(), temps); compileExpression(expr.getRootNode(), temps);
int maxArguments = 1;
for (int i = 0; i < (int) operation.size(); i++)
if (operation[i]->getNumArguments() > maxArguments)
maxArguments = operation[i]->getNumArguments();
argValues.resize(maxArguments);
generateJitCode();
} }
CompiledExpression::~CompiledExpression() { CompiledExpression::~CompiledExpression() {
...@@ -52,7 +59,7 @@ CompiledExpression::~CompiledExpression() { ...@@ -52,7 +59,7 @@ CompiledExpression::~CompiledExpression() {
delete operation[i]; delete operation[i];
} }
CompiledExpression::CompiledExpression(const CompiledExpression& expression) { CompiledExpression::CompiledExpression(const CompiledExpression& expression) : jitCode(NULL) {
*this = expression; *this = expression;
} }
...@@ -66,6 +73,7 @@ CompiledExpression& CompiledExpression::operator=(const CompiledExpression& expr ...@@ -66,6 +73,7 @@ CompiledExpression& CompiledExpression::operator=(const CompiledExpression& expr
operation.resize(expression.operation.size()); operation.resize(expression.operation.size());
for (int i = 0; i < (int) operation.size(); i++) for (int i = 0; i < (int) operation.size(); i++)
operation[i] = expression.operation[i]->clone(); operation[i] = expression.operation[i]->clone();
generateJitCode();
return *this; return *this;
} }
...@@ -103,11 +111,8 @@ void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vecto ...@@ -103,11 +111,8 @@ void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vecto
sequential = false; sequential = false;
if (sequential) if (sequential)
arguments[stepIndex].push_back(args[0]); arguments[stepIndex].push_back(args[0]);
else { else
arguments[stepIndex] = args; arguments[stepIndex] = args;
if (args.size() > argValues.size())
argValues.resize(args.size(), 0.0);
}
} }
} }
temps.push_back(make_pair(node, workspace.size())); temps.push_back(make_pair(node, workspace.size()));
...@@ -133,6 +138,9 @@ double& CompiledExpression::getVariableReference(const string& name) { ...@@ -133,6 +138,9 @@ double& CompiledExpression::getVariableReference(const string& name) {
} }
double CompiledExpression::evaluate() const { double CompiledExpression::evaluate() const {
if (jitCode != NULL)
return ((double (*)()) jitCode)();
// Loop over the operations and evaluate each one. // Loop over the operations and evaluate each one.
for (int step = 0; step < operation.size(); step++) { for (int step = 0; step < operation.size(); step++) {
...@@ -147,3 +155,54 @@ double CompiledExpression::evaluate() const { ...@@ -147,3 +155,54 @@ double CompiledExpression::evaluate() const {
} }
return workspace[workspace.size()-1]; return workspace[workspace.size()-1];
} }
static double evaluateOperation(Operation* op, double* args) {
map<string, double>* dummyVariables = NULL;
return op->evaluate(args, *dummyVariables);
}
void CompiledExpression::generateJitCode() {
X86Compiler c(&runtime);
c.addFunc(kFuncConvHost, FuncBuilder0<double>());
vector<X86XmmVar> workspaceVar(workspace.size());
for (int i = 0; i < (int) workspaceVar.size(); i++)
workspaceVar[i] = c.newXmmVar(kX86VarTypeXmmSd);
X86GpVar workspacePointer(c);
X86GpVar argsPointer(c);
c.mov(workspacePointer, imm_ptr(&workspace[0]));
c.mov(argsPointer, imm_ptr(&argValues[0]));
// Load the variables.
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
map<string, int>::iterator index = variableIndices.find(*iter);
c.movsd(workspaceVar[index->second], x86::ptr(workspacePointer, 8*index->second, 0));
}
// Evaluate the operations.
for (int step = 0; step < (int) operation.size(); step++) {
const vector<int>& args = arguments[step];
if (args.size() == 1) {
// One or more sequential arguments.
for (int i = 0; i < operation[step]->getNumArguments(); i++)
c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[0]+i]);
}
else {
// Two or more non-sequential arguments.
for (int i = 0; i < (int) args.size(); i++)
c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]);
}
X86GpVar fn(c, kVarTypeIntPtr);
c.mov(fn, imm_ptr((void*) evaluateOperation));
X86CallNode* call = c.call(fn, kFuncConvHost, FuncBuilder2<double, Operation*, double*>());
call->setArg(0, imm_ptr(operation[step]));
call->setArg(1, imm_ptr(&argValues[0]));
call->setRet(0, workspaceVar[target[step]]);
}
c.ret(workspacePointer);
c.endFunc();
jitCode = c.make();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment