Unverified commit aafb8b5b, authored by Peter Eastman and committed by GitHub
Browse files

Vectorize custom expressions on CPU (#3552)

* Began implementing vectorization of Lepton expressions

* Tests for vector expressions

* Implemented CompiledVectorExpression for x86

* Bug fix

* Optimized select() on ARM

* Optimized select() on x86

* CompiledVectorExpression supports AVX

* Bug fix

* Updated docs

* Use VEX encoded instructions for CompiledExpression

* Optimized min() and max() on x86

* Optimized min() and max() on ARM

* Fixed compilation error

* Upgrade AsmJit
parent 6fb1c8a4
...@@ -4993,7 +4993,7 @@ EmitDone: ...@@ -4993,7 +4993,7 @@ EmitDone:
if (Support::test(options, InstOptions::kReserved)) { if (Support::test(options, InstOptions::kReserved)) {
#ifndef ASMJIT_NO_LOGGING #ifndef ASMJIT_NO_LOGGING
if (_logger) if (_logger)
EmitterUtils::logInstructionEmitted(this, instId, options, o0, o1, o2, opExt, 0, 0, writer.cursor()); EmitterUtils::logInstructionEmitted(this, BaseInst::composeARMInstId(instId, instCC), options, o0, o1, o2, opExt, 0, 0, writer.cursor());
#endif #endif
} }
......
...@@ -169,6 +169,18 @@ public: ...@@ -169,6 +169,18 @@ public:
//! \} //! \}
//! \name Compiler specific
//! \{
//! Special pseudo-instruction that can be used to load a memory address into `o0` GP register.
//!
//! \note At the moment this instruction is only useful to load a stack allocated address into a GP register
//! for further use. It makes very little sense to use it for anything else. The semantics of this instruction
//! is the same as X86 `LEA` (load effective address) instruction.
inline Error loadAddressOf(const Gp& o0, const Mem& o1) { return _emitter()->_emitI(Inst::kIdAdr, o0, o1); }
//! \}
//! \name Function Call & Ret Intrinsics //! \name Function Call & Ret Intrinsics
//! \{ //! \{
......
...@@ -117,7 +117,7 @@ ASMJIT_FAVOR_SIZE Error EmitHelper::emitRegMove( ...@@ -117,7 +117,7 @@ ASMJIT_FAVOR_SIZE Error EmitHelper::emitRegMove(
case TypeId::kUInt32: case TypeId::kUInt32:
case TypeId::kInt64: case TypeId::kInt64:
case TypeId::kUInt64: case TypeId::kUInt64:
return emitter->mov(src.as<Gp>().x(), dst.as<Gp>().x()); return emitter->mov(dst.as<Gp>().x(), src.as<Gp>().x());
default: { default: {
if (TypeUtils::isFloat32(typeId) || TypeUtils::isVec32(typeId)) if (TypeUtils::isFloat32(typeId) || TypeUtils::isVec32(typeId))
......
...@@ -139,7 +139,7 @@ Error InstInternal::queryRWInfo(Arch arch, const BaseInst& inst, const Operand_* ...@@ -139,7 +139,7 @@ Error InstInternal::queryRWInfo(Arch arch, const BaseInst& inst, const Operand_*
if (ASMJIT_UNLIKELY(!Inst::isDefinedId(realId))) if (ASMJIT_UNLIKELY(!Inst::isDefinedId(realId)))
return DebugUtils::errored(kErrorInvalidInstruction); return DebugUtils::errored(kErrorInvalidInstruction);
out->_instFlags = 0; out->_instFlags = InstRWFlags::kNone;
out->_opCount = uint8_t(opCount); out->_opCount = uint8_t(opCount);
out->_rmFeature = 0; out->_rmFeature = 0;
out->_extraReg.reset(); out->_extraReg.reset();
......
...@@ -102,7 +102,7 @@ public: ...@@ -102,7 +102,7 @@ public:
// TODO: [ARM] This is just a workaround... // TODO: [ARM] This is just a workaround...
static InstControlFlow getControlFlowType(InstId instId) noexcept { static InstControlFlow getControlFlowType(InstId instId) noexcept {
switch (instId) { switch (BaseInst::extractRealId(instId)) {
case Inst::kIdB: case Inst::kIdB:
case Inst::kIdBr: case Inst::kIdBr:
if (BaseInst::extractARMCondCode(instId) == CondCode::kAL) if (BaseInst::extractARMCondCode(instId) == CondCode::kAL)
...@@ -127,8 +127,8 @@ static InstControlFlow getControlFlowType(InstId instId) noexcept { ...@@ -127,8 +127,8 @@ static InstControlFlow getControlFlowType(InstId instId) noexcept {
Error RACFGBuilder::onInst(InstNode* inst, InstControlFlow& controlType, RAInstBuilder& ib) noexcept { Error RACFGBuilder::onInst(InstNode* inst, InstControlFlow& controlType, RAInstBuilder& ib) noexcept {
InstRWInfo rwInfo; InstRWInfo rwInfo;
if (Inst::isDefinedId(inst->realId())) {
InstId instId = inst->id(); InstId instId = inst->id();
if (Inst::isDefinedId(instId)) {
uint32_t opCount = inst->opCount(); uint32_t opCount = inst->opCount();
const Operand* opArray = inst->operands(); const Operand* opArray = inst->operands();
ASMJIT_PROPAGATE(InstInternal::queryRWInfo(_arch, inst->baseInst(), opArray, opCount, &rwInfo)); ASMJIT_PROPAGATE(InstInternal::queryRWInfo(_arch, inst->baseInst(), opArray, opCount, &rwInfo));
...@@ -136,6 +136,8 @@ Error RACFGBuilder::onInst(InstNode* inst, InstControlFlow& controlType, RAInstB ...@@ -136,6 +136,8 @@ Error RACFGBuilder::onInst(InstNode* inst, InstControlFlow& controlType, RAInstB
const InstDB::InstInfo& instInfo = InstDB::infoById(instId); const InstDB::InstInfo& instInfo = InstDB::infoById(instId);
uint32_t singleRegOps = 0; uint32_t singleRegOps = 0;
ib.addInstRWFlags(rwInfo.instFlags());
if (opCount) { if (opCount) {
uint32_t consecutiveOffset = 0xFFFFFFFFu; uint32_t consecutiveOffset = 0xFFFFFFFFu;
uint32_t consecutiveParent = Globals::kInvalidId; uint32_t consecutiveParent = Globals::kInvalidId;
...@@ -715,6 +717,50 @@ ASMJIT_FAVOR_SPEED Error ARMRAPass::_rewrite(BaseNode* first, BaseNode* stop) no ...@@ -715,6 +717,50 @@ ASMJIT_FAVOR_SPEED Error ARMRAPass::_rewrite(BaseNode* first, BaseNode* stop) no
} }
} }
} }
// Rewrite `loadAddressOf()` construct.
if (inst->realId() == Inst::kIdAdr && inst->opCount() == 2 && inst->op(1).isMem()) {
BaseMem mem = inst->op(1).as<BaseMem>();
int64_t offset = mem.offset();
if (!mem.hasBaseOrIndex()) {
inst->setId(Inst::kIdMov);
inst->setOp(1, Imm(offset));
}
else {
if (mem.hasIndex())
return DebugUtils::errored(kErrorInvalidAddressIndex);
GpX dst(inst->op(0).as<Gp>().id());
GpX base(mem.baseId());
InstId arithInstId = offset < 0 ? Inst::kIdSub : Inst::kIdAdd;
uint64_t absOffset = offset < 0 ? Support::neg(uint64_t(offset)) : uint64_t(offset);
inst->setId(arithInstId);
inst->setOpCount(3);
inst->setOp(1, base);
inst->setOp(2, Imm(absOffset));
// Use two operations if the offset cannot be encoded with ADD/SUB.
if (absOffset > 0xFFFu && (absOffset & ~uint64_t(0xFFF000u)) != 0) {
if (absOffset <= 0xFFFFFFu) {
cc()->_setCursor(inst->prev());
ASMJIT_PROPAGATE(cc()->emit(arithInstId, dst, base, Imm(absOffset & 0xFFFu)));
inst->setOp(1, dst);
inst->setOp(2, Imm(absOffset & 0xFFF000u));
}
else {
cc()->_setCursor(inst->prev());
ASMJIT_PROPAGATE(cc()->emit(Inst::kIdMov, inst->op(0), Imm(absOffset)));
inst->setOp(1, base);
inst->setOp(2, dst);
}
}
}
}
} }
node = next; node = next;
......
...@@ -152,7 +152,7 @@ enum class InstHints : uint8_t { ...@@ -152,7 +152,7 @@ enum class InstHints : uint8_t {
//! No feature hints. //! No feature hints.
kNoHints = 0, kNoHints = 0,
//! Architecture supports a register swap by using a single instructio. //! Architecture supports a register swap by using a single instruction.
kRegSwap = 0x01u, kRegSwap = 0x01u,
//! Architecture provides push/pop instructions. //! Architecture provides push/pop instructions.
kPushPop = 0x02u kPushPop = 0x02u
......
...@@ -356,7 +356,7 @@ struct OffsetFormat { ...@@ -356,7 +356,7 @@ struct OffsetFormat {
//! Returns the size of the region/instruction where the offset is encoded. //! Returns the size of the region/instruction where the offset is encoded.
inline uint32_t regionSize() const noexcept { return _regionSize; } inline uint32_t regionSize() const noexcept { return _regionSize; }
//! Returns the the offset of the word relative to the start of the region where the offset is. //! Returns the offset of the word relative to the start of the region where the offset is.
inline uint32_t valueOffset() const noexcept { return _valueOffset; } inline uint32_t valueOffset() const noexcept { return _valueOffset; }
//! Returns the size of the data-type (word) that contains the offset, in bytes. //! Returns the size of the data-type (word) that contains the offset, in bytes.
......
...@@ -143,7 +143,7 @@ Error formatLabel( ...@@ -143,7 +143,7 @@ Error formatLabel(
} }
if (le->type() == LabelType::kAnonymous) if (le->type() == LabelType::kAnonymous)
ASMJIT_PROPAGATE(sb.append("L%u@", labelId)); ASMJIT_PROPAGATE(sb.appendFormat("L%u@", labelId));
return sb.append(le->name()); return sb.append(le->name());
} }
else { else {
......
...@@ -1127,7 +1127,7 @@ public: ...@@ -1127,7 +1127,7 @@ public:
//! Tests whether the callee must adjust SP before returning (X86-STDCALL only) //! Tests whether the callee must adjust SP before returning (X86-STDCALL only)
inline bool hasCalleeStackCleanup() const noexcept { return _calleeStackCleanup != 0; } inline bool hasCalleeStackCleanup() const noexcept { return _calleeStackCleanup != 0; }
//! Returns home many bytes of the stack the the callee must adjust before returning (X86-STDCALL only) //! Returns how many bytes of the stack the callee must adjust before returning (X86-STDCALL only)
inline uint32_t calleeStackCleanup() const noexcept { return _calleeStackCleanup; } inline uint32_t calleeStackCleanup() const noexcept { return _calleeStackCleanup; }
//! Returns call stack alignment. //! Returns call stack alignment.
......
...@@ -312,6 +312,10 @@ public: ...@@ -312,6 +312,10 @@ public:
return id | (uint32_t(cc) << Support::ConstCTZ<uint32_t(InstIdParts::kARM_Cond)>::value); return id | (uint32_t(cc) << Support::ConstCTZ<uint32_t(InstIdParts::kARM_Cond)>::value);
} }
//! Returns `id` masked by `InstIdParts::kRealId`, i.e. only the bits that belong to the real instruction id.
//!
//! NOTE(review): presumably this strips additional parts packed into the id, such as the ARM condition code
//! bits (`InstIdParts::kARM_Cond`) - confirm against the `InstIdParts` mask definitions.
static inline constexpr InstId extractRealId(uint32_t id) noexcept {
return id & uint32_t(InstIdParts::kRealId);
}
static inline constexpr arm::CondCode extractARMCondCode(uint32_t id) noexcept { static inline constexpr arm::CondCode extractARMCondCode(uint32_t id) noexcept {
return (arm::CondCode)((uint32_t(id) & uint32_t(InstIdParts::kARM_Cond)) >> Support::ConstCTZ<uint32_t(InstIdParts::kARM_Cond)>::value); return (arm::CondCode)((uint32_t(id) & uint32_t(InstIdParts::kARM_Cond)) >> Support::ConstCTZ<uint32_t(InstIdParts::kARM_Cond)>::value);
} }
...@@ -614,13 +618,25 @@ struct OpRWInfo { ...@@ -614,13 +618,25 @@ struct OpRWInfo {
//! \} //! \}
}; };
//! Flags used by \ref InstRWInfo.
//!
//! The values are bit flags stored in a `uint32_t`.
enum class InstRWFlags : uint32_t {
//! No flags.
kNone = 0x00000000u,
//! Describes a move operation.
//!
//! This flag is used by RA to eliminate moves that are guaranteed to be moves only.
kMovOp = 0x00000001u
};
// NOTE(review): presumably provides the bitwise operators (`|`, `&`, `~`, `|=`, `&=`) that callers apply to
// `InstRWFlags` values - confirm against the macro definition.
ASMJIT_DEFINE_ENUM_FLAGS(InstRWFlags)
//! Read/Write information of an instruction. //! Read/Write information of an instruction.
struct InstRWInfo { struct InstRWInfo {
//! \name Members //! \name Members
//! \{ //! \{
//! Instruction flags (there are no flags at the moment, this field is reserved). //! Instruction flags, see \ref InstRWFlags.
uint32_t _instFlags; InstRWFlags _instFlags;
//! CPU flags read. //! CPU flags read.
CpuRWFlags _readFlags; CpuRWFlags _readFlags;
//! CPU flags written. //! CPU flags written.
...@@ -646,6 +662,20 @@ struct InstRWInfo { ...@@ -646,6 +662,20 @@ struct InstRWInfo {
//! \} //! \}
//! \name Instruction Flags
//! \{
//! Returns flags associated with the instruction, see \ref InstRWFlags.
inline InstRWFlags instFlags() const noexcept { return _instFlags; }
//! Tests whether the instruction flags contain `flag`.
inline bool hasInstFlag(InstRWFlags flag) const noexcept { return Support::test(_instFlags, flag); }
//! Tests whether the instruction flags contain \ref InstRWFlags::kMovOp.
inline bool isMovOp() const noexcept { return hasInstFlag(InstRWFlags::kMovOp); }
//! \}
//! \name CPU Flags Information //! \name CPU Flags Information
//! \{ //! \{
......
...@@ -836,6 +836,34 @@ Error RALocalAllocator::allocInst(InstNode* node) noexcept { ...@@ -836,6 +836,34 @@ Error RALocalAllocator::allocInst(InstNode* node) noexcept {
// STEP 9 // STEP 9
// ------ // ------
// //
// Vector registers can be clobbered partially by invoke - find if that's the case and clobber when necessary.
if (node->isInvoke() && group == RegGroup::kVec) {
const InvokeNode* invokeNode = node->as<InvokeNode>();
RegMask maybeClobberedRegs = invokeNode->detail().callConv().preservedRegs(group) & _curAssignment.assigned(group);
if (maybeClobberedRegs) {
uint32_t saveRestoreVecSize = invokeNode->detail().callConv().saveRestoreRegSize(group);
Support::BitWordIterator<RegMask> it(maybeClobberedRegs);
do {
uint32_t physId = it.next();
uint32_t workId = _curAssignment.physToWorkId(group, physId);
RAWorkReg* workReg = workRegById(workId);
uint32_t virtSize = workReg->virtReg()->virtSize();
if (virtSize > saveRestoreVecSize) {
ASMJIT_PROPAGATE(onSpillReg(group, workId, physId));
}
} while (it.hasNext());
}
}
// STEP 10
// -------
//
// Assign OUT registers. // Assign OUT registers.
if (outPending) { if (outPending) {
......
...@@ -276,6 +276,8 @@ public: ...@@ -276,6 +276,8 @@ public:
//! Parent block. //! Parent block.
RABlock* _block; RABlock* _block;
//! Instruction RW flags.
InstRWFlags _instRWFlags;
//! Aggregated RATiedFlags from all operands & instruction specific flags. //! Aggregated RATiedFlags from all operands & instruction specific flags.
RATiedFlags _flags; RATiedFlags _flags;
//! Total count of RATiedReg's. //! Total count of RATiedReg's.
...@@ -298,9 +300,10 @@ public: ...@@ -298,9 +300,10 @@ public:
//! \name Construction & Destruction //! \name Construction & Destruction
//! \{ //! \{
inline RAInst(RABlock* block, RATiedFlags flags, uint32_t tiedTotal, const RARegMask& clobberedRegs) noexcept { inline RAInst(RABlock* block, InstRWFlags instRWFlags, RATiedFlags tiedFlags, uint32_t tiedTotal, const RARegMask& clobberedRegs) noexcept {
_block = block; _block = block;
_flags = flags; _instRWFlags = instRWFlags;
_flags = tiedFlags;
_tiedTotal = tiedTotal; _tiedTotal = tiedTotal;
_tiedIndex.reset(); _tiedIndex.reset();
_tiedCount.reset(); _tiedCount.reset();
...@@ -314,6 +317,13 @@ public: ...@@ -314,6 +317,13 @@ public:
//! \name Accessors //! \name Accessors
//! \{ //! \{
//! Returns the instruction RW flags (see \ref InstRWFlags) stored in this `RAInst`.
inline InstRWFlags instRWFlags() const noexcept { return _instRWFlags; }
//! Tests whether the given `flag` is present in instruction RW flags.
inline bool hasInstRWFlag(InstRWFlags flag) const noexcept { return Support::test(_instRWFlags, flag); }
//! Adds `flags` to instruction RW flags.
inline void addInstRWFlags(InstRWFlags flags) noexcept { _instRWFlags |= flags; }
//! Returns the instruction flags. //! Returns the instruction flags.
inline RATiedFlags flags() const noexcept { return _flags; } inline RATiedFlags flags() const noexcept { return _flags; }
//! Tests whether the instruction has flag `flag`. //! Tests whether the instruction has flag `flag`.
...@@ -376,6 +386,9 @@ public: ...@@ -376,6 +386,9 @@ public:
//! \name Members //! \name Members
//! \{ //! \{
//! Instruction RW flags.
InstRWFlags _instRWFlags;
//! Flags combined from all RATiedReg's. //! Flags combined from all RATiedReg's.
RATiedFlags _aggregatedFlags; RATiedFlags _aggregatedFlags;
//! Flags that will be cleared before storing the aggregated flags to `RAInst`. //! Flags that will be cleared before storing the aggregated flags to `RAInst`.
...@@ -400,6 +413,7 @@ public: ...@@ -400,6 +413,7 @@ public:
inline void init() noexcept { reset(); } inline void init() noexcept { reset(); }
inline void reset() noexcept { inline void reset() noexcept {
_instRWFlags = InstRWFlags::kNone;
_aggregatedFlags = RATiedFlags::kNone; _aggregatedFlags = RATiedFlags::kNone;
_forbiddenFlags = RATiedFlags::kNone; _forbiddenFlags = RATiedFlags::kNone;
_count.reset(); _count.reset();
...@@ -414,10 +428,15 @@ public: ...@@ -414,10 +428,15 @@ public:
//! \name Accessors //! \name Accessors
//! \{ //! \{
inline RATiedFlags aggregatedFlags() const noexcept { return _aggregatedFlags; } inline InstRWFlags instRWFlags() const noexcept { return _instRWFlags; }
inline RATiedFlags forbiddenFlags() const noexcept { return _forbiddenFlags; } inline bool hasInstRWFlag(InstRWFlags flag) const noexcept { return Support::test(_instRWFlags, flag); }
inline void addInstRWFlags(InstRWFlags flags) noexcept { _instRWFlags |= flags; }
inline void clearInstRWFlags(InstRWFlags flags) noexcept { _instRWFlags &= ~flags; }
inline RATiedFlags aggregatedFlags() const noexcept { return _aggregatedFlags; }
inline void addAggregatedFlags(RATiedFlags flags) noexcept { _aggregatedFlags |= flags; } inline void addAggregatedFlags(RATiedFlags flags) noexcept { _aggregatedFlags |= flags; }
inline RATiedFlags forbiddenFlags() const noexcept { return _forbiddenFlags; }
inline void addForbiddenFlags(RATiedFlags flags) noexcept { _forbiddenFlags |= flags; } inline void addForbiddenFlags(RATiedFlags flags) noexcept { _forbiddenFlags |= flags; }
//! Returns the number of tied registers added to the builder. //! Returns the number of tied registers added to the builder.
...@@ -859,16 +878,16 @@ public: ...@@ -859,16 +878,16 @@ public:
return _exits.append(allocator(), block); return _exits.append(allocator(), block);
} }
ASMJIT_FORCE_INLINE RAInst* newRAInst(RABlock* block, RATiedFlags flags, uint32_t tiedRegCount, const RARegMask& clobberedRegs) noexcept { ASMJIT_FORCE_INLINE RAInst* newRAInst(RABlock* block, InstRWFlags instRWFlags, RATiedFlags flags, uint32_t tiedRegCount, const RARegMask& clobberedRegs) noexcept {
void* p = zone()->alloc(RAInst::sizeOf(tiedRegCount)); void* p = zone()->alloc(RAInst::sizeOf(tiedRegCount));
if (ASMJIT_UNLIKELY(!p)) if (ASMJIT_UNLIKELY(!p))
return nullptr; return nullptr;
return new(p) RAInst(block, flags, tiedRegCount, clobberedRegs); return new(p) RAInst(block, instRWFlags, flags, tiedRegCount, clobberedRegs);
} }
ASMJIT_FORCE_INLINE Error assignRAInst(BaseNode* node, RABlock* block, RAInstBuilder& ib) noexcept { ASMJIT_FORCE_INLINE Error assignRAInst(BaseNode* node, RABlock* block, RAInstBuilder& ib) noexcept {
uint32_t tiedRegCount = ib.tiedRegCount(); uint32_t tiedRegCount = ib.tiedRegCount();
RAInst* raInst = newRAInst(block, ib.aggregatedFlags(), tiedRegCount, ib._clobbered); RAInst* raInst = newRAInst(block, ib.instRWFlags(), ib.aggregatedFlags(), tiedRegCount, ib._clobbered);
if (ASMJIT_UNLIKELY(!raInst)) if (ASMJIT_UNLIKELY(!raInst))
return DebugUtils::errored(kErrorOutOfMemory); return DebugUtils::errored(kErrorOutOfMemory);
......
...@@ -30,7 +30,7 @@ static inline uint32_t getXmmMovInst(const FuncFrame& frame) { ...@@ -30,7 +30,7 @@ static inline uint32_t getXmmMovInst(const FuncFrame& frame) {
: (avx ? Inst::kIdVmovups : Inst::kIdMovups); : (avx ? Inst::kIdVmovups : Inst::kIdMovups);
} }
//! Converts `size` to a 'kmov?' instructio. //! Converts `size` to a 'kmov?' instruction.
static inline uint32_t kmovInstFromSize(uint32_t size) noexcept { static inline uint32_t kmovInstFromSize(uint32_t size) noexcept {
switch (size) { switch (size) {
case 1: return Inst::kIdKmovb; case 1: return Inst::kIdKmovb;
......
...@@ -606,7 +606,7 @@ namespace Inst { ...@@ -606,7 +606,7 @@ namespace Inst {
kIdPaddusb, //!< Instruction 'paddusb' {MMX|SSE2}. kIdPaddusb, //!< Instruction 'paddusb' {MMX|SSE2}.
kIdPaddusw, //!< Instruction 'paddusw' {MMX|SSE2}. kIdPaddusw, //!< Instruction 'paddusw' {MMX|SSE2}.
kIdPaddw, //!< Instruction 'paddw' {MMX|SSE2}. kIdPaddw, //!< Instruction 'paddw' {MMX|SSE2}.
kIdPalignr, //!< Instruction 'palignr' {SSSE3}. kIdPalignr, //!< Instruction 'palignr' {SSSE3}.
kIdPand, //!< Instruction 'pand' {MMX|SSE2}. kIdPand, //!< Instruction 'pand' {MMX|SSE2}.
kIdPandn, //!< Instruction 'pandn' {MMX|SSE2}. kIdPandn, //!< Instruction 'pandn' {MMX|SSE2}.
kIdPause, //!< Instruction 'pause'. kIdPause, //!< Instruction 'pause'.
......
...@@ -776,6 +776,15 @@ static ASMJIT_FORCE_INLINE Error rwHandleAVX512(const BaseInst& inst, const Inst ...@@ -776,6 +776,15 @@ static ASMJIT_FORCE_INLINE Error rwHandleAVX512(const BaseInst& inst, const Inst
return kErrorOk; return kErrorOk;
} }
//! Tests whether every register in `regs[0..opCount)` has the same register type as `regs[0]`.
//!
//! `opCount` must be greater than zero (asserted), so that there is a first register to compare against.
static ASMJIT_FORCE_INLINE bool hasSameRegType(const BaseReg* regs, size_t opCount) noexcept {
  ASMJIT_ASSERT(opCount > 0);

  // The first register defines the type that all remaining registers must match.
  const RegType firstType = regs[0].type();

  size_t index = 1;
  while (index < opCount) {
    if (regs[index].type() != firstType)
      return false;
    index++;
  }

  return true;
}
Error InstInternal::queryRWInfo(Arch arch, const BaseInst& inst, const Operand_* operands, size_t opCount, InstRWInfo* out) noexcept { Error InstInternal::queryRWInfo(Arch arch, const BaseInst& inst, const Operand_* operands, size_t opCount, InstRWInfo* out) noexcept {
// Only called when `arch` matches X86 family. // Only called when `arch` matches X86 family.
ASMJIT_ASSERT(Environment::isFamilyX86(arch)); ASMJIT_ASSERT(Environment::isFamilyX86(arch));
...@@ -801,13 +810,14 @@ Error InstInternal::queryRWInfo(Arch arch, const BaseInst& inst, const Operand_* ...@@ -801,13 +810,14 @@ Error InstInternal::queryRWInfo(Arch arch, const BaseInst& inst, const Operand_*
: InstDB::rwInfoB[InstDB::rwInfoIndexB[instId]]; : InstDB::rwInfoB[InstDB::rwInfoIndexB[instId]];
const InstDB::RWInfoRm& instRmInfo = InstDB::rwInfoRm[instRwInfo.rmInfo]; const InstDB::RWInfoRm& instRmInfo = InstDB::rwInfoRm[instRwInfo.rmInfo];
out->_instFlags = 0; out->_instFlags = InstDB::_instFlagsTable[additionalInfo._instFlagsIndex];
out->_opCount = uint8_t(opCount); out->_opCount = uint8_t(opCount);
out->_rmFeature = instRmInfo.rmFeature; out->_rmFeature = instRmInfo.rmFeature;
out->_extraReg.reset(); out->_extraReg.reset();
out->_readFlags = CpuRWFlags(rwFlags.readFlags); out->_readFlags = CpuRWFlags(rwFlags.readFlags);
out->_writeFlags = CpuRWFlags(rwFlags.writeFlags); out->_writeFlags = CpuRWFlags(rwFlags.writeFlags);
uint32_t opTypeMask = 0u;
uint32_t nativeGpSize = Environment::registerSizeFromArch(arch); uint32_t nativeGpSize = Environment::registerSizeFromArch(arch);
constexpr OpRWFlags R = OpRWFlags::kRead; constexpr OpRWFlags R = OpRWFlags::kRead;
...@@ -827,6 +837,8 @@ Error InstInternal::queryRWInfo(Arch arch, const BaseInst& inst, const Operand_* ...@@ -827,6 +837,8 @@ Error InstInternal::queryRWInfo(Arch arch, const BaseInst& inst, const Operand_*
const Operand_& srcOp = operands[i]; const Operand_& srcOp = operands[i];
const InstDB::RWInfoOp& rwOpData = InstDB::rwInfoOp[instRwInfo.opInfoIndex[i]]; const InstDB::RWInfoOp& rwOpData = InstDB::rwInfoOp[instRwInfo.opInfoIndex[i]];
opTypeMask |= Support::bitMask(srcOp.opType());
if (!srcOp.isRegOrMem()) { if (!srcOp.isRegOrMem()) {
op.reset(); op.reset();
continue; continue;
...@@ -878,8 +890,23 @@ Error InstInternal::queryRWInfo(Arch arch, const BaseInst& inst, const Operand_* ...@@ -878,8 +890,23 @@ Error InstInternal::queryRWInfo(Arch arch, const BaseInst& inst, const Operand_*
} }
} }
if (instRmInfo.flags & (InstDB::RWInfoRm::kFlagPextrw | InstDB::RWInfoRm::kFlagFeatureIfRMI)) { // Only keep kMovOp if the instruction is actually register to register move of the same kind.
if (instRmInfo.flags & InstDB::RWInfoRm::kFlagPextrw) { if (out->hasInstFlag(InstRWFlags::kMovOp)) {
if (!(opCount >= 2 && opTypeMask == Support::bitMask(OperandType::kReg) && hasSameRegType(reinterpret_cast<const BaseReg*>(operands), opCount)))
out->_instFlags &= ~InstRWFlags::kMovOp;
}
// Special cases require more logic.
if (instRmInfo.flags & (InstDB::RWInfoRm::kFlagMovssMovsd | InstDB::RWInfoRm::kFlagPextrw | InstDB::RWInfoRm::kFlagFeatureIfRMI)) {
if (instRmInfo.flags & InstDB::RWInfoRm::kFlagMovssMovsd) {
if (opCount == 2) {
if (operands[0].isReg() && operands[1].isReg()) {
// Doesn't zero extend the destination.
out->_operands[0]._extendByteMask = 0;
}
}
}
else if (instRmInfo.flags & InstDB::RWInfoRm::kFlagPextrw) {
if (opCount == 3 && Reg::isMm(operands[1])) { if (opCount == 3 && Reg::isMm(operands[1])) {
out->_rmFeature = 0; out->_rmFeature = 0;
rmOpsMask = 0; rmOpsMask = 0;
...@@ -930,6 +957,9 @@ Error InstInternal::queryRWInfo(Arch arch, const BaseInst& inst, const Operand_* ...@@ -930,6 +957,9 @@ Error InstInternal::queryRWInfo(Arch arch, const BaseInst& inst, const Operand_*
// used to move between GP, segment, control and debug registers. Moving between GP registers also allows // used to move between GP, segment, control and debug registers. Moving between GP registers also allows
// the use of a memory operand. // the use of a memory operand.
// We will again set the flag if it's actually a move from GP to GP register, otherwise this flag cannot be set.
out->_instFlags &= ~InstRWFlags::kMovOp;
if (opCount == 2) { if (opCount == 2) {
if (operands[0].isReg() && operands[1].isReg()) { if (operands[0].isReg() && operands[1].isReg()) {
const Reg& o0 = operands[0].as<Reg>(); const Reg& o0 = operands[0].as<Reg>();
...@@ -940,6 +970,7 @@ Error InstInternal::queryRWInfo(Arch arch, const BaseInst& inst, const Operand_* ...@@ -940,6 +970,7 @@ Error InstInternal::queryRWInfo(Arch arch, const BaseInst& inst, const Operand_*
out->_operands[1].reset(R | RegM, operands[1].size()); out->_operands[1].reset(R | RegM, operands[1].size());
rwZeroExtendGp(out->_operands[0], operands[0].as<Gp>(), nativeGpSize); rwZeroExtendGp(out->_operands[0], operands[0].as<Gp>(), nativeGpSize);
out->_instFlags |= InstRWFlags::kMovOp;
return kErrorOk; return kErrorOk;
} }
...@@ -1647,10 +1678,10 @@ UNIT(x86_inst_api_rm_feature) { ...@@ -1647,10 +1678,10 @@ UNIT(x86_inst_api_rm_feature) {
InstRWInfo rwi; InstRWInfo rwi;
queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdPextrw, InstOptions::kNone, eax, mm1, imm(1)); queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdPextrw, InstOptions::kNone, eax, mm1, imm(1));
EXPECT(rwi._rmFeature == 0); EXPECT(rwi.rmFeature() == 0);
queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdPextrw, InstOptions::kNone, eax, xmm1, imm(1)); queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdPextrw, InstOptions::kNone, eax, xmm1, imm(1));
EXPECT(rwi._rmFeature == CpuFeatures::X86::kSSE4_1); EXPECT(rwi.rmFeature() == CpuFeatures::X86::kSSE4_1);
} }
INFO("Verifying whether RM/feature is reported correctly for AVX512 shift instructions"); INFO("Verifying whether RM/feature is reported correctly for AVX512 shift instructions");
...@@ -1658,40 +1689,40 @@ UNIT(x86_inst_api_rm_feature) { ...@@ -1658,40 +1689,40 @@ UNIT(x86_inst_api_rm_feature) {
InstRWInfo rwi; InstRWInfo rwi;
queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpslld, InstOptions::kNone, xmm1, xmm2, imm(8)); queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpslld, InstOptions::kNone, xmm1, xmm2, imm(8));
EXPECT(rwi._rmFeature == CpuFeatures::X86::kAVX512_F); EXPECT(rwi.rmFeature() == CpuFeatures::X86::kAVX512_F);
queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsllq, InstOptions::kNone, ymm1, ymm2, imm(8)); queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsllq, InstOptions::kNone, ymm1, ymm2, imm(8));
EXPECT(rwi._rmFeature == CpuFeatures::X86::kAVX512_F); EXPECT(rwi.rmFeature() == CpuFeatures::X86::kAVX512_F);
queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsrad, InstOptions::kNone, xmm1, xmm2, imm(8)); queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsrad, InstOptions::kNone, xmm1, xmm2, imm(8));
EXPECT(rwi._rmFeature == CpuFeatures::X86::kAVX512_F); EXPECT(rwi.rmFeature() == CpuFeatures::X86::kAVX512_F);
queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsrld, InstOptions::kNone, ymm1, ymm2, imm(8)); queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsrld, InstOptions::kNone, ymm1, ymm2, imm(8));
EXPECT(rwi._rmFeature == CpuFeatures::X86::kAVX512_F); EXPECT(rwi.rmFeature() == CpuFeatures::X86::kAVX512_F);
queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsrlq, InstOptions::kNone, xmm1, xmm2, imm(8)); queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsrlq, InstOptions::kNone, xmm1, xmm2, imm(8));
EXPECT(rwi._rmFeature == CpuFeatures::X86::kAVX512_F); EXPECT(rwi.rmFeature() == CpuFeatures::X86::kAVX512_F);
queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpslldq, InstOptions::kNone, xmm1, xmm2, imm(8)); queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpslldq, InstOptions::kNone, xmm1, xmm2, imm(8));
EXPECT(rwi._rmFeature == CpuFeatures::X86::kAVX512_BW); EXPECT(rwi.rmFeature() == CpuFeatures::X86::kAVX512_BW);
queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsllw, InstOptions::kNone, ymm1, ymm2, imm(8)); queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsllw, InstOptions::kNone, ymm1, ymm2, imm(8));
EXPECT(rwi._rmFeature == CpuFeatures::X86::kAVX512_BW); EXPECT(rwi.rmFeature() == CpuFeatures::X86::kAVX512_BW);
queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsraw, InstOptions::kNone, xmm1, xmm2, imm(8)); queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsraw, InstOptions::kNone, xmm1, xmm2, imm(8));
EXPECT(rwi._rmFeature == CpuFeatures::X86::kAVX512_BW); EXPECT(rwi.rmFeature() == CpuFeatures::X86::kAVX512_BW);
queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsrldq, InstOptions::kNone, ymm1, ymm2, imm(8)); queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsrldq, InstOptions::kNone, ymm1, ymm2, imm(8));
EXPECT(rwi._rmFeature == CpuFeatures::X86::kAVX512_BW); EXPECT(rwi.rmFeature() == CpuFeatures::X86::kAVX512_BW);
queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsrlw, InstOptions::kNone, xmm1, xmm2, imm(8)); queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsrlw, InstOptions::kNone, xmm1, xmm2, imm(8));
EXPECT(rwi._rmFeature == CpuFeatures::X86::kAVX512_BW); EXPECT(rwi.rmFeature() == CpuFeatures::X86::kAVX512_BW);
queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpslld, InstOptions::kNone, xmm1, xmm2, xmm3); queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpslld, InstOptions::kNone, xmm1, xmm2, xmm3);
EXPECT(rwi._rmFeature == 0); EXPECT(rwi.rmFeature() == 0);
queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsllw, InstOptions::kNone, xmm1, xmm2, xmm3); queryRWInfoSimple(&rwi, Arch::kX64, Inst::kIdVpsllw, InstOptions::kNone, xmm1, xmm2, xmm3);
EXPECT(rwi._rmFeature == 0); EXPECT(rwi.rmFeature() == 0);
} }
} }
#endif #endif
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -461,7 +461,7 @@ struct InstInfo { ...@@ -461,7 +461,7 @@ struct InstInfo {
//! \name Accessors //! \name Accessors
//! \{ //! \{
//! Returns common information, see `CommonInfo`. //! Returns common information, see \ref CommonInfo.
inline const CommonInfo& commonInfo() const noexcept { return _commonInfoTable[_commonInfoIndex]; } inline const CommonInfo& commonInfo() const noexcept { return _commonInfoTable[_commonInfoIndex]; }
//! Returns instruction flags, see \ref Flags. //! Returns instruction flags, see \ref Flags.
......
...@@ -189,12 +189,12 @@ enum EncodingId : uint32_t { ...@@ -189,12 +189,12 @@ enum EncodingId : uint32_t {
//! Additional information table, provides CPU extensions required to execute an instruction and RW flags. //! Additional information table, provides CPU extensions required to execute an instruction and RW flags.
struct AdditionalInfo { struct AdditionalInfo {
//! Features vector. //! Index to `_instFlagsTable`.
uint8_t _features[6]; uint8_t _instFlagsIndex;
//! Index to `_rwFlagsTable`. //! Index to `_rwFlagsTable`.
uint8_t _rwFlagsIndex; uint8_t _rwFlagsIndex;
//! Reserved for future use. //! Features vector.
uint8_t _reserved; uint8_t _features[6];
inline const uint8_t* featuresBegin() const noexcept { return _features; } inline const uint8_t* featuresBegin() const noexcept { return _features; }
inline const uint8_t* featuresEnd() const noexcept { return _features + ASMJIT_ARRAY_SIZE(_features); } inline const uint8_t* featuresEnd() const noexcept { return _features + ASMJIT_ARRAY_SIZE(_features); }
...@@ -260,8 +260,12 @@ struct RWInfoRm { ...@@ -260,8 +260,12 @@ struct RWInfoRm {
enum Flags : uint8_t { enum Flags : uint8_t {
kFlagAmbiguous = 0x01, kFlagAmbiguous = 0x01,
//! Special semantics for PEXTRW - memory operand can only be used with SSE4.1 instruction and it's forbidden in MMX.
kFlagPextrw = 0x02, kFlagPextrw = 0x02,
kFlagFeatureIfRMI = 0x04 //! Special semantics for MOVSS and MOVSD - doesn't zero extend the destination if the operation is a reg to reg move.
kFlagMovssMovsd = 0x04,
//! Special semantics for AVX shift instructions that do not provide reg/mem in AVX/AVX2 mode (AVX-512 is required).
kFlagFeatureIfRMI = 0x08
}; };
uint8_t category; uint8_t category;
...@@ -285,6 +289,7 @@ extern const RWInfo rwInfoB[]; ...@@ -285,6 +289,7 @@ extern const RWInfo rwInfoB[];
extern const RWInfoOp rwInfoOp[]; extern const RWInfoOp rwInfoOp[];
extern const RWInfoRm rwInfoRm[]; extern const RWInfoRm rwInfoRm[];
extern const RWFlagsInfoTable _rwFlagsInfoTable[]; extern const RWFlagsInfoTable _rwFlagsInfoTable[];
extern const InstRWFlags _instFlagsTable[];
extern const uint32_t _mainOpcodeTable[]; extern const uint32_t _mainOpcodeTable[];
extern const uint32_t _altOpcodeTable[]; extern const uint32_t _altOpcodeTable[];
......
...@@ -126,6 +126,12 @@ Error RACFGBuilder::onInst(InstNode* inst, InstControlFlow& cf, RAInstBuilder& i ...@@ -126,6 +126,12 @@ Error RACFGBuilder::onInst(InstNode* inst, InstControlFlow& cf, RAInstBuilder& i
bool hasGpbHiConstraint = false; bool hasGpbHiConstraint = false;
uint32_t singleRegOps = 0; uint32_t singleRegOps = 0;
// Copy instruction RW flags to instruction builder except kMovOp, which is propagated manually later.
ib.addInstRWFlags(rwInfo.instFlags() & ~InstRWFlags::kMovOp);
// Mask of all operand types used by the instruction - can be used as an optimization later.
uint32_t opTypesMask = 0u;
if (opCount) { if (opCount) {
// The mask is for all registers, but we are mostly interested in AVX-512 registers at the moment. The mask // The mask is for all registers, but we are mostly interested in AVX-512 registers at the moment. The mask
// will be combined with all available registers of the Compiler at the end so we it never use more registers // will be combined with all available registers of the Compiler at the end so we it never use more registers
...@@ -167,6 +173,8 @@ Error RACFGBuilder::onInst(InstNode* inst, InstControlFlow& cf, RAInstBuilder& i ...@@ -167,6 +173,8 @@ Error RACFGBuilder::onInst(InstNode* inst, InstControlFlow& cf, RAInstBuilder& i
const Operand& op = opArray[i]; const Operand& op = opArray[i];
const OpRWInfo& opRwInfo = rwInfo.operand(i); const OpRWInfo& opRwInfo = rwInfo.operand(i);
opTypesMask |= 1u << uint32_t(op.opType());
if (op.isReg()) { if (op.isReg()) {
// Register Operand // Register Operand
// ---------------- // ----------------
...@@ -394,6 +402,24 @@ Error RACFGBuilder::onInst(InstNode* inst, InstControlFlow& cf, RAInstBuilder& i ...@@ -394,6 +402,24 @@ Error RACFGBuilder::onInst(InstNode* inst, InstControlFlow& cf, RAInstBuilder& i
} }
} }
// If this instruction has move semantics then check whether it could be eliminated if all virtual registers
// are allocated into the same register. Take into account the virtual size of the destination register as that's
// more important than a physical register size in this case.
if (rwInfo.hasInstFlag(InstRWFlags::kMovOp) && !inst->hasExtraReg() && Support::bitTest(opTypesMask, uint32_t(OperandType::kReg))) {
// AVX+ move instructions have 3 operand form - the first two operands must be the same to guarantee move semantics.
if (opCount == 2 || (opCount == 3 && opArray[0] == opArray[1])) {
uint32_t vIndex = Operand::virtIdToIndex(opArray[0].as<Reg>().id());
if (vIndex < Operand::kVirtIdCount) {
const VirtReg* vReg = _cc->virtRegByIndex(vIndex);
const OpRWInfo& opRwInfo = rwInfo.operand(0);
uint64_t remainingByteMask = vReg->workReg()->regByteMask() & ~opRwInfo.writeByteMask();
if (remainingByteMask == 0u || (remainingByteMask & opRwInfo.extendByteMask()) == 0)
ib.addInstRWFlags(InstRWFlags::kMovOp);
}
}
}
// Handle X86 constraints. // Handle X86 constraints.
if (hasGpbHiConstraint) { if (hasGpbHiConstraint) {
for (RATiedReg& tiedReg : ib) { for (RATiedReg& tiedReg : ib) {
...@@ -1251,6 +1277,10 @@ ASMJIT_FAVOR_SPEED Error X86RAPass::_rewrite(BaseNode* first, BaseNode* stop) no ...@@ -1251,6 +1277,10 @@ ASMJIT_FAVOR_SPEED Error X86RAPass::_rewrite(BaseNode* first, BaseNode* stop) no
// Rewrite virtual registers into physical registers. // Rewrite virtual registers into physical registers.
if (raInst) { if (raInst) {
// This data is allocated by Zone passed to `runOnFunction()`, which will be reset after the RA pass finishes.
// So reset this data to prevent having a dead pointer after the RA pass is complete.
node->resetPassData();
// If the instruction contains pass data (raInst) then it was a subject for register allocation and must be // If the instruction contains pass data (raInst) then it was a subject for register allocation and must be
// rewritten to use physical regs. // rewritten to use physical regs.
RATiedReg* tiedRegs = raInst->tiedRegs(); RATiedReg* tiedRegs = raInst->tiedRegs();
...@@ -1274,16 +1304,25 @@ ASMJIT_FAVOR_SPEED Error X86RAPass::_rewrite(BaseNode* first, BaseNode* stop) no ...@@ -1274,16 +1304,25 @@ ASMJIT_FAVOR_SPEED Error X86RAPass::_rewrite(BaseNode* first, BaseNode* stop) no
} }
} }
// Transform VEX instruction to EVEX when necessary.
if (raInst->isTransformable()) { if (raInst->isTransformable()) {
if (maxRegId > 15) { if (maxRegId > 15) {
// Transform VEX instruction to EVEX.
inst->setId(transformVexToEvex(inst->id())); inst->setId(transformVexToEvex(inst->id()));
} }
} }
// This data is allocated by Zone passed to `runOnFunction()`, which will be reset after the RA pass finishes. // Remove moves that do not do anything.
// So reset this data to prevent having a dead pointer after the RA pass is complete. //
node->resetPassData(); // Usually these moves are inserted during code generation and originally they used different registers. If RA
// allocated these into the same register such redundant mov would appear.
if (raInst->hasInstRWFlag(InstRWFlags::kMovOp) && !inst->hasExtraReg()) {
if (inst->opCount() == 2) {
if (inst->op(0) == inst->op(1)) {
cc()->removeNode(node);
goto Next;
}
}
}
if (ASMJIT_UNLIKELY(node->type() != NodeType::kInst)) { if (ASMJIT_UNLIKELY(node->type() != NodeType::kInst)) {
// FuncRet terminates the flow, it must either be removed if the exit label is next to it (optimization) or // FuncRet terminates the flow, it must either be removed if the exit label is next to it (optimization) or
...@@ -1327,6 +1366,7 @@ ASMJIT_FAVOR_SPEED Error X86RAPass::_rewrite(BaseNode* first, BaseNode* stop) no ...@@ -1327,6 +1366,7 @@ ASMJIT_FAVOR_SPEED Error X86RAPass::_rewrite(BaseNode* first, BaseNode* stop) no
} }
} }
Next:
node = next; node = next;
} }
......
#ifndef LEPTON_VECTOR_EXPRESSION_H_
#define LEPTON_VECTOR_EXPRESSION_H_
/* -------------------------------------------------------------------------- *
* Lepton *
* -------------------------------------------------------------------------- *
* This is part of the Lepton expression parser originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "ExpressionTreeNode.h"
#include "windowsIncludes.h"
#include <array>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
#ifdef LEPTON_USE_JIT
#if defined(__ARM__) || defined(__ARM64__)
#include "asmjit/a64.h"
#else
#include "asmjit/x86.h"
#endif
#endif
namespace Lepton {
class Operation;
class ParsedExpression;
/**
* A CompiledVectorExpression is a highly optimized representation of an expression for cases when you want to evaluate
* it many times as quickly as possible. It is similar to CompiledExpression, with the extra feature that it uses the CPU's
* vector unit (AVX on x86, NEON on ARM) to evaluate the expression for multiple sets of arguments at once. It also differs
* from CompiledExpression and ParsedExpression in using single precision rather than double precision to evaluate the expression.
* You should treat it as an opaque object; none of the internal representation is visible.
*
* A CompiledVectorExpression is created by calling createCompiledVectorExpression() on a ParsedExpression. When you create
* it, you must specify the width of the vectors on which to compute the expression. The allowed widths depend on the type of
* CPU it is running on. 4 is always allowed, and 8 is allowed on x86 processors with AVX. Call getAllowedWidths() to query
* the allowed values.
*
* WARNING: CompiledVectorExpression is NOT thread safe. You should never access a CompiledVectorExpression from two threads at
* the same time.
*/
class LEPTON_EXPORT CompiledVectorExpression {
public:
/**
 * Default constructor, copy constructor, destructor and copy assignment are all user-declared.
 * Their definitions are not part of this header.
 */
CompiledVectorExpression();
CompiledVectorExpression(const CompiledVectorExpression& expression);
~CompiledVectorExpression();
CompiledVectorExpression& operator=(const CompiledVectorExpression& expression);
/**
 * Get the width of the vectors on which the expression is computed.
 */
int getWidth() const;
/**
 * Get the names of all variables used by this expression.
 */
const std::set<std::string>& getVariables() const;
/**
 * Get a pointer to the memory location where the value of a particular variable is stored. This can be used
 * to set the value of the variable before calling evaluate().
 *
 * NOTE(review): the behavior for a name that is not used by the expression is not visible from this
 * header - confirm against the implementation before relying on it.
 *
 * @param name    the name of the variable to query
 * @return a pointer to N floating point values, where N is the vector width
 */
float* getVariablePointer(const std::string& name);
/**
 * You can optionally specify the memory locations from which the values of variables should be read.
 * This is useful, for example, when several expressions all use the same variable. You can then set
 * the value of that variable in one place, and it will be seen by all of them. The location should
 * be a pointer to N floating point values, where N is the vector width.
 *
 * NOTE(review): the map is taken by non-const reference, so temporaries and const maps cannot be
 * passed. Whether the map itself is modified is not visible from this header.
 *
 * @param variableLocations    map from variable name to the location of that variable's N values
 */
void setVariableLocations(std::map<std::string, float*>& variableLocations);
/**
 * Evaluate the expression. The values of all variables should have been set before calling this.
 *
 * @return a pointer to N floating point values, where N is the vector width
 */
const float* evaluate() const;
/**
 * Get the list of vector widths that are supported on the current processor.
 */
static const std::vector<int>& getAllowedWidths();
private:
// ParsedExpression is a friend, which gives it access to the private (expression, width) constructor below.
friend class ParsedExpression;
// Private constructor taking the expression to compile and the requested vector width.
CompiledVectorExpression(const ParsedExpression& expression, int width);
// Helpers working on a list of (expression node, int) pairs. NOTE(review): presumably the int is the
// workspace slot assigned to that node, and workspaceSize is grown as slots are handed out - confirm.
void compileExpression(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps, int& workspaceSize);
int findTempIndex(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps);
// Vector width; presumably the value reported by getWidth().
int width;
// Variable name -> float pointer. Presumably the table behind getVariablePointer() - confirm.
std::map<std::string, float*> variablePointers;
// Pairs of float pointers. NOTE(review): looks like copy requests for externally supplied variable
// locations; which element is source and which is destination is not visible here - confirm.
std::vector<std::pair<float*, float*> > variablesToCopy;
// Per-step tables: argument indices, target index and operation. NOTE(review): assumed to be indexed
// in parallel, one entry per evaluation step; ownership of the Operation pointers is not visible here.
std::vector<std::vector<int> > arguments;
std::vector<int> target;
std::vector<Operation*> operation;
// Variable name -> integer index; presumably an index into the workspace - TODO confirm.
std::map<std::string, int> variableIndices;
// Set of variable names; has the type returned by getVariables().
std::set<std::string> variableNames;
// Scratch storage. Declared mutable so that const member functions (evaluate() is const) can write to it.
mutable std::vector<float> workspace;
// NOTE(review): this buffer is double precision while the rest of the class works on float - presumably
// it is used to pass arguments to the double precision Operation interface; confirm.
mutable std::vector<double> argValues;
// NOTE(review): purpose is not visible from this header.
std::map<std::string, double> dummyVariables;
// Pointer to a function taking no arguments and returning nothing. It is declared outside the
// LEPTON_USE_JIT block, so the member exists in both configurations. It has no in-class initializer,
// so every constructor has to assign it.
void (*jitCode)();
// Everything below exists only when LEPTON_USE_JIT is defined. Because this adds data members, the
// layout of the class depends on that macro: all translation units that include this header must be
// compiled with the same setting.
#ifdef LEPTON_USE_JIT
void findPowerGroups(std::vector<std::vector<int> >& groups, std::vector<std::vector<int> >& groupPowers, std::vector<int>& stepGroup);
void generateJitCode();
// Code generation helpers for calling a scalar float function with one or two arguments. The signature
// depends on the target: AsmJit a64 compiler and arm::Vec registers when __ARM__ or __ARM64__ is
// defined, AsmJit x86 compiler and Ymm registers otherwise.
#if defined(__ARM__) || defined(__ARM64__)
void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, float (*function)(float));
void generateTwoArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg1, asmjit::arm::Vec& arg2, float (*function)(float, float));
#else
void generateSingleArgCall(asmjit::x86::Compiler& c, asmjit::x86::Ymm& dest, asmjit::x86::Ymm& arg, float (*function)(float));
void generateTwoArgCall(asmjit::x86::Compiler& c, asmjit::x86::Ymm& dest, asmjit::x86::Ymm& arg1, asmjit::x86::Ymm& arg2, float (*function)(float, float));
#endif
// Float constants; presumably literal values referenced by the generated code - confirm.
std::vector<float> constants;
// AsmJit runtime object held by value. NOTE(review): presumably it owns the memory of the generated
// code that jitCode points to - confirm how the copy constructor and copy assignment handle it.
asmjit::JitRuntime runtime;
#endif
};
} // namespace Lepton
#endif /*LEPTON_VECTOR_EXPRESSION_H_*/
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment