Unverified Commit aafb8b5b authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Vectorize custom expressions on CPU (#3552)

* Began implementing vectorization of Lepton expressions

* Tests for vector expressions

* Implemented CompiledVectorExpression for x86

* Bug fix

* Optimized select() on ARM

* Optimized select() on x86

* CompiledVectorExpression supports AVX

* Bug fix

* Updated docs

* Use VEX encoded instructions for CompiledExpression

* Optimized min() and max() on x86

* Optimized min() and max() on ARM

* Fixed compilation error

* Upgrade AsmJit
parent 6fb1c8a4
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009=2021 Stanford University and the Authors. * * Portions copyright (c) 2009-2022 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -41,6 +41,7 @@ namespace Lepton { ...@@ -41,6 +41,7 @@ namespace Lepton {
class CompiledExpression; class CompiledExpression;
class ExpressionProgram; class ExpressionProgram;
class CompiledVectorExpression;
/** /**
* This class represents the result of parsing an expression. It provides methods for working with the * This class represents the result of parsing an expression. It provides methods for working with the
...@@ -102,6 +103,16 @@ public: ...@@ -102,6 +103,16 @@ public:
* Create a CompiledExpression that represents the same calculation as this expression. * Create a CompiledExpression that represents the same calculation as this expression.
*/ */
CompiledExpression createCompiledExpression() const; CompiledExpression createCompiledExpression() const;
/**
* Create a CompiledVectorExpression that allows the expression to be evaluated efficiently
* using the CPU's vector unit.
*
* @param width the width of the vectors to evaluate it on. The allowed values
* depend on the CPU. 4 is always allowed, and 8 is allowed on
* x86 processors with AVX. Call CompiledVectorExpression::getAllowedWidths()
* to query the allowed widths on the current processor.
*/
CompiledVectorExpression createCompiledVectorExpression(int width) const;
/** /**
* Create a new ParsedExpression which is identical to this one, except that the names of some * Create a new ParsedExpression which is identical to this one, except that the names of some
* variables have been changed. * variables have been changed.
......
...@@ -151,7 +151,7 @@ void CompiledExpression::setVariableLocations(map<string, double*>& variableLoca ...@@ -151,7 +151,7 @@ void CompiledExpression::setVariableLocations(map<string, double*>& variableLoca
if (workspace.size() > 0) if (workspace.size() > 0)
generateJitCode(); generateJitCode();
#else #endif
// Make a list of all variables we will need to copy before evaluating the expression. // Make a list of all variables we will need to copy before evaluating the expression.
variablesToCopy.clear(); variablesToCopy.clear();
...@@ -160,13 +160,11 @@ void CompiledExpression::setVariableLocations(map<string, double*>& variableLoca ...@@ -160,13 +160,11 @@ void CompiledExpression::setVariableLocations(map<string, double*>& variableLoca
if (pointer != variablePointers.end()) if (pointer != variablePointers.end())
variablesToCopy.push_back(make_pair(&workspace[iter->second], pointer->second)); variablesToCopy.push_back(make_pair(&workspace[iter->second], pointer->second));
} }
#endif
} }
double CompiledExpression::evaluate() const { double CompiledExpression::evaluate() const {
#ifdef LEPTON_USE_JIT if (jitCode)
return jitCode(); return jitCode();
#else
for (int i = 0; i < variablesToCopy.size(); i++) for (int i = 0; i < variablesToCopy.size(); i++)
*variablesToCopy[i].first = *variablesToCopy[i].second; *variablesToCopy[i].first = *variablesToCopy[i].second;
...@@ -183,7 +181,6 @@ double CompiledExpression::evaluate() const { ...@@ -183,7 +181,6 @@ double CompiledExpression::evaluate() const {
} }
} }
return workspace[workspace.size()-1]; return workspace[workspace.size()-1];
#endif
} }
#ifdef LEPTON_USE_JIT #ifdef LEPTON_USE_JIT
...@@ -458,14 +455,24 @@ void CompiledExpression::generateJitCode() { ...@@ -458,14 +455,24 @@ void CompiledExpression::generateJitCode() {
case Operation::POWER_CONSTANT: case Operation::POWER_CONSTANT:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow); generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow);
break; break;
case Operation::MIN:
c.fmin(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::MAX:
c.fmax(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::ABS: case Operation::ABS:
c.fabs(workspaceVar[target[step]], workspaceVar[args[0]]); c.fabs(workspaceVar[target[step]], workspaceVar[args[0]]);
break; break;
case Operation::FLOOR: case Operation::FLOOR:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], floor); c.frintm(workspaceVar[target[step]], workspaceVar[args[0]]);
break; break;
case Operation::CEIL: case Operation::CEIL:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], ceil); c.frintp(workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::SELECT:
c.fcmeq(workspaceVar[target[step]], workspaceVar[args[0]], imm(0));
c.bsl(workspaceVar[target[step]], workspaceVar[args[2]], workspaceVar[args[1]]);
break; break;
default: default:
// Just invoke evaluateOperation(). // Just invoke evaluateOperation().
...@@ -507,10 +514,14 @@ void CompiledExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, ar ...@@ -507,10 +514,14 @@ void CompiledExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, ar
} }
#else #else
void CompiledExpression::generateJitCode() { void CompiledExpression::generateJitCode() {
const CpuInfo& cpu = CpuInfo::host();
if (!cpu.hasFeature(CpuFeatures::X86::kAVX))
return;
CodeHolder code; CodeHolder code;
code.init(runtime.environment()); code.init(runtime.environment());
x86::Compiler c(&code); x86::Compiler c(&code);
c.addFunc(FuncSignatureT<double>()); FuncNode* funcNode = c.addFunc(FuncSignatureT<double>());
funcNode->frame().setAvxEnabled();
vector<x86::Xmm> workspaceVar(workspace.size()); vector<x86::Xmm> workspaceVar(workspace.size());
for (int i = 0; i < (int) workspaceVar.size(); i++) for (int i = 0; i < (int) workspaceVar.size(); i++)
workspaceVar[i] = c.newXmmSd(); workspaceVar[i] = c.newXmmSd();
...@@ -522,11 +533,11 @@ void CompiledExpression::generateJitCode() { ...@@ -522,11 +533,11 @@ void CompiledExpression::generateJitCode() {
// Load the arguments into variables. // Load the arguments into variables.
x86::Gp variablePointer = c.newIntPtr();
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
map<string, int>::iterator index = variableIndices.find(*iter); map<string, int>::iterator index = variableIndices.find(*iter);
x86::Gp variablePointer = c.newIntPtr();
c.mov(variablePointer, imm(&getVariableReference(index->first))); c.mov(variablePointer, imm(&getVariableReference(index->first)));
c.movsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0)); c.vmovsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0));
} }
// Make a list of all constants that will be needed for evaluation. // Make a list of all constants that will be needed for evaluation.
...@@ -549,6 +560,10 @@ void CompiledExpression::generateJitCode() { ...@@ -549,6 +560,10 @@ void CompiledExpression::generateJitCode() {
value = 1.0; value = 1.0;
else if (op.getId() == Operation::DELTA) else if (op.getId() == Operation::DELTA)
value = 1.0; value = 1.0;
else if (op.getId() == Operation::ABS) {
long long mask = 0x7FFFFFFFFFFFFFFF;
value = *reinterpret_cast<double*>(&mask);
}
else if (op.getId() == Operation::POWER_CONSTANT) { else if (op.getId() == Operation::POWER_CONSTANT) {
if (stepGroup[step] == -1) if (stepGroup[step] == -1)
value = dynamic_cast<Operation::PowerConstant&>(op).getValue(); value = dynamic_cast<Operation::PowerConstant&>(op).getValue();
...@@ -579,7 +594,7 @@ void CompiledExpression::generateJitCode() { ...@@ -579,7 +594,7 @@ void CompiledExpression::generateJitCode() {
c.mov(constantsPointer, imm(&constants[0])); c.mov(constantsPointer, imm(&constants[0]));
for (int i = 0; i < (int) constants.size(); i++) { for (int i = 0; i < (int) constants.size(); i++) {
constantVar[i] = c.newXmmSd(); constantVar[i] = c.newXmmSd();
c.movsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0)); c.vmovsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0));
} }
} }
...@@ -598,10 +613,9 @@ void CompiledExpression::generateJitCode() { ...@@ -598,10 +613,9 @@ void CompiledExpression::generateJitCode() {
vector<int>& powers = groupPowers[stepGroup[step]]; vector<int>& powers = groupPowers[stepGroup[step]];
x86::Xmm multiplier = c.newXmmSd(); x86::Xmm multiplier = c.newXmmSd();
if (powers[0] > 0) if (powers[0] > 0)
c.movsd(multiplier, workspaceVar[arguments[step][0]]); c.vmovsd(multiplier, workspaceVar[arguments[step][0]], workspaceVar[arguments[step][0]]);
else { else {
c.movsd(multiplier, constantVar[operationConstantIndex[step]]); c.vdivsd(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]);
c.divsd(multiplier, workspaceVar[arguments[step][0]]);
for (int i = 0; i < powers.size(); i++) for (int i = 0; i < powers.size(); i++)
powers[i] = -powers[i]; powers[i] = -powers[i];
} }
...@@ -612,9 +626,9 @@ void CompiledExpression::generateJitCode() { ...@@ -612,9 +626,9 @@ void CompiledExpression::generateJitCode() {
for (int i = 0; i < group.size(); i++) { for (int i = 0; i < group.size(); i++) {
if (powers[i]%2 == 1) { if (powers[i]%2 == 1) {
if (!hasAssigned[i]) if (!hasAssigned[i])
c.movsd(workspaceVar[target[group[i]]], multiplier); c.vmovsd(workspaceVar[target[group[i]]], multiplier, multiplier);
else else
c.mulsd(workspaceVar[target[group[i]]], multiplier); c.vmulsd(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier);
hasAssigned[i] = true; hasAssigned[i] = true;
} }
powers[i] >>= 1; powers[i] >>= 1;
...@@ -622,7 +636,7 @@ void CompiledExpression::generateJitCode() { ...@@ -622,7 +636,7 @@ void CompiledExpression::generateJitCode() {
done = false; done = false;
} }
if (!done) if (!done)
c.mulsd(multiplier, multiplier); c.vmulsd(multiplier, multiplier, multiplier);
} }
for (int step : group) for (int step : group)
hasComputedPower[step] = true; hasComputedPower[step] = true;
...@@ -644,33 +658,29 @@ void CompiledExpression::generateJitCode() { ...@@ -644,33 +658,29 @@ void CompiledExpression::generateJitCode() {
switch (op.getId()) { switch (op.getId()) {
case Operation::CONSTANT: case Operation::CONSTANT:
c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); c.vmovsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], constantVar[operationConstantIndex[step]]);
break; break;
case Operation::ADD: case Operation::ADD:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
c.addsd(workspaceVar[target[step]], workspaceVar[args[1]]);
break; break;
case Operation::SUBTRACT: case Operation::SUBTRACT:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.vsubsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
c.subsd(workspaceVar[target[step]], workspaceVar[args[1]]);
break; break;
case Operation::MULTIPLY: case Operation::MULTIPLY:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[1]]);
break; break;
case Operation::DIVIDE: case Operation::DIVIDE:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.vdivsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
c.divsd(workspaceVar[target[step]], workspaceVar[args[1]]);
break; break;
case Operation::POWER: case Operation::POWER:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow); generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow);
break; break;
case Operation::NEGATE: case Operation::NEGATE:
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
c.subsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.vsubsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
break; break;
case Operation::SQRT: case Operation::SQRT:
c.sqrtsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.vsqrtsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
break; break;
case Operation::EXP: case Operation::EXP:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp); generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp);
...@@ -709,53 +719,62 @@ void CompiledExpression::generateJitCode() { ...@@ -709,53 +719,62 @@ void CompiledExpression::generateJitCode() {
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh); generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh);
break; break;
case Operation::STEP: case Operation::STEP:
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18 c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18
c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break; break;
case Operation::DELTA: case Operation::DELTA:
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16 c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16
c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break; break;
case Operation::SQUARE: case Operation::SQUARE:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
break; break;
case Operation::CUBE: case Operation::CUBE:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.vmulsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
break; break;
case Operation::RECIPROCAL: case Operation::RECIPROCAL:
c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); c.vdivsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]);
c.divsd(workspaceVar[target[step]], workspaceVar[args[0]]);
break; break;
case Operation::ADD_CONSTANT: case Operation::ADD_CONSTANT:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
c.addsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break; break;
case Operation::MULTIPLY_CONSTANT: case Operation::MULTIPLY_CONSTANT:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break; break;
case Operation::POWER_CONSTANT: case Operation::POWER_CONSTANT:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow); generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow);
break; break;
case Operation::MIN:
c.vminsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::MAX:
c.vmaxsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::ABS: case Operation::ABS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs); c.vandpd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break; break;
case Operation::FLOOR: case Operation::FLOOR:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], floor); c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(1));
break; break;
case Operation::CEIL: case Operation::CEIL:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], ceil); c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(2));
break;
case Operation::SELECT:
{
x86::Xmm mask = c.newXmmSd();
c.vxorps(mask, mask, mask);
c.vcmpsd(mask, mask, workspaceVar[args[0]], imm(0)); // Comparison mode is _CMP_EQ_OQ = 0
c.vblendvps(workspaceVar[target[step]], workspaceVar[args[1]], workspaceVar[args[2]], mask);
break; break;
}
default: default:
// Just invoke evaluateOperation(). // Just invoke evaluateOperation().
for (int i = 0; i < (int) args.size(); i++) for (int i = 0; i < (int) args.size(); i++)
c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]); c.vmovsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]);
x86::Gp fn = c.newIntPtr(); x86::Gp fn = c.newIntPtr();
c.mov(fn, imm((void*) evaluateOperation)); c.mov(fn, imm((void*) evaluateOperation));
InvokeNode* invoke; InvokeNode* invoke;
......
/* -------------------------------------------------------------------------- *
* Lepton *
* -------------------------------------------------------------------------- *
* This is part of the Lepton expression parser originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "lepton/CompiledVectorExpression.h"
#include "lepton/Operation.h"
#include "lepton/ParsedExpression.h"
#include <algorithm>
#include <utility>
using namespace Lepton;
using namespace std;
#ifdef LEPTON_USE_JIT
using namespace asmjit;
#endif
CompiledVectorExpression::CompiledVectorExpression() : jitCode(NULL) {
}
CompiledVectorExpression::CompiledVectorExpression(const ParsedExpression& expression, int width) : jitCode(NULL), width(width) {
const vector<int> allowedWidths = getAllowedWidths();
if (find(allowedWidths.begin(), allowedWidths.end(), width) == allowedWidths.end())
throw Exception("Unsupported width for vector expression: "+to_string(width));
ParsedExpression expr = expression.optimize(); // Just in case it wasn't already optimized.
vector<pair<ExpressionTreeNode, int> > temps;
int workspaceSize = 0;
compileExpression(expr.getRootNode(), temps, workspaceSize);
workspace.resize(workspaceSize*width);
int maxArguments = 1;
for (int i = 0; i < (int) operation.size(); i++)
if (operation[i]->getNumArguments() > maxArguments)
maxArguments = operation[i]->getNumArguments();
argValues.resize(maxArguments);
#ifdef LEPTON_USE_JIT
generateJitCode();
#endif
}
CompiledVectorExpression::~CompiledVectorExpression() {
for (int i = 0; i < (int) operation.size(); i++)
if (operation[i] != NULL)
delete operation[i];
}
CompiledVectorExpression::CompiledVectorExpression(const CompiledVectorExpression& expression) : jitCode(NULL) {
*this = expression;
}
CompiledVectorExpression& CompiledVectorExpression::operator=(const CompiledVectorExpression& expression) {
arguments = expression.arguments;
width = expression.width;
target = expression.target;
variableIndices = expression.variableIndices;
variableNames = expression.variableNames;
workspace.resize(expression.workspace.size());
argValues.resize(expression.argValues.size());
operation.resize(expression.operation.size());
for (int i = 0; i < (int) operation.size(); i++)
operation[i] = expression.operation[i]->clone();
setVariableLocations(variablePointers);
return *this;
}
const vector<int>& CompiledVectorExpression::getAllowedWidths() {
static vector<int> widths;
if (widths.size() == 0) {
widths.push_back(4);
#ifdef LEPTON_USE_JIT
const CpuInfo& cpu = CpuInfo::host();
if (cpu.hasFeature(CpuFeatures::X86::kAVX))
widths.push_back(8);
#endif
}
return widths;
}
void CompiledVectorExpression::compileExpression(const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, int> >& temps, int& workspaceSize) {
if (findTempIndex(node, temps) != -1)
return; // We have already processed a node identical to this one.
// Process the child nodes.
vector<int> args;
for (int i = 0; i < node.getChildren().size(); i++) {
compileExpression(node.getChildren()[i], temps, workspaceSize);
args.push_back(findTempIndex(node.getChildren()[i], temps));
}
// Process this node.
if (node.getOperation().getId() == Operation::VARIABLE) {
variableIndices[node.getOperation().getName()] = workspaceSize;
variableNames.insert(node.getOperation().getName());
}
else {
int stepIndex = (int) arguments.size();
arguments.push_back(vector<int>());
target.push_back(workspaceSize);
operation.push_back(node.getOperation().clone());
if (args.size() == 0)
arguments[stepIndex].push_back(0); // The value won't actually be used. We just need something there.
else {
// If the arguments are sequential, we can just pass a pointer to the first one.
bool sequential = true;
for (int i = 1; i < args.size(); i++)
if (args[i] != args[i - 1] + 1)
sequential = false;
if (sequential)
arguments[stepIndex].push_back(args[0]);
else
arguments[stepIndex] = args;
}
}
temps.push_back(make_pair(node, workspaceSize));
workspaceSize++;
}
int CompiledVectorExpression::findTempIndex(const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, int> >& temps) {
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
return i;
return -1;
}
int CompiledVectorExpression::getWidth() const {
return width;
}
const set<string>& CompiledVectorExpression::getVariables() const {
return variableNames;
}
float* CompiledVectorExpression::getVariablePointer(const string& name) {
map<string, float*>::iterator pointer = variablePointers.find(name);
if (pointer != variablePointers.end())
return pointer->second;
map<string, int>::iterator index = variableIndices.find(name);
if (index == variableIndices.end())
throw Exception("getVariableReference: Unknown variable '" + name + "'");
return &workspace[index->second*width];
}
void CompiledVectorExpression::setVariableLocations(map<string, float*>& variableLocations) {
variablePointers = variableLocations;
#ifdef LEPTON_USE_JIT
// Rebuild the JIT code.
if (workspace.size() > 0)
generateJitCode();
#endif
// Make a list of all variables we will need to copy before evaluating the expression.
variablesToCopy.clear();
for (map<string, int>::const_iterator iter = variableIndices.begin(); iter != variableIndices.end(); ++iter) {
map<string, float*>::iterator pointer = variablePointers.find(iter->first);
if (pointer != variablePointers.end())
variablesToCopy.push_back(make_pair(&workspace[iter->second*width], pointer->second));
}
}
const float* CompiledVectorExpression::evaluate() const {
if (jitCode) {
jitCode();
return &workspace[workspace.size()-width];
}
for (int i = 0; i < variablesToCopy.size(); i++)
for (int j = 0; j < width; j++)
variablesToCopy[i].first[j] = variablesToCopy[i].second[j];
// Loop over the operations and evaluate each one.
for (int step = 0; step < operation.size(); step++) {
const vector<int>& args = arguments[step];
if (args.size() == 1) {
for (int j = 0; j < width; j++) {
for (int i = 0; i < operation[step]->getNumArguments(); i++)
argValues[i] = workspace[(args[0]+i)*width+j];
workspace[target[step]*width+j] = operation[step]->evaluate(&argValues[0], dummyVariables);
}
} else {
for (int j = 0; j < width; j++) {
for (int i = 0; i < args.size(); i++)
argValues[i] = workspace[args[i]*width+j];
workspace[target[step]*width+j] = operation[step]->evaluate(&argValues[0], dummyVariables);
}
}
}
return &workspace[workspace.size()-width];
}
#ifdef LEPTON_USE_JIT
static double evaluateOperation(Operation* op, double* args) {
static map<string, double> dummyVariables;
return op->evaluate(args, dummyVariables);
}
void CompiledVectorExpression::findPowerGroups(vector<vector<int> >& groups, vector<vector<int> >& groupPowers, vector<int>& stepGroup) {
// Identify every step that raises an argument to an integer power.
vector<int> stepPower(operation.size(), 0);
vector<int> stepArg(operation.size(), -1);
for (int step = 0; step < operation.size(); step++) {
Operation& op = *operation[step];
int power = 0;
if (op.getId() == Operation::SQUARE)
power = 2;
else if (op.getId() == Operation::CUBE)
power = 3;
else if (op.getId() == Operation::POWER_CONSTANT) {
double realPower = dynamic_cast<const Operation::PowerConstant*> (&op)->getValue();
if (realPower == (int) realPower)
power = (int) realPower;
}
if (power != 0) {
stepPower[step] = power;
stepArg[step] = arguments[step][0];
}
}
// Find groups that operate on the same argument and whose powers have the same sign.
stepGroup.resize(operation.size(), -1);
for (int i = 0; i < operation.size(); i++) {
if (stepGroup[i] != -1)
continue;
vector<int> group, power;
for (int j = i; j < operation.size(); j++) {
if (stepArg[i] == stepArg[j] && stepPower[i] * stepPower[j] > 0) {
stepGroup[j] = groups.size();
group.push_back(j);
power.push_back(stepPower[j]);
}
}
groups.push_back(group);
groupPowers.push_back(power);
}
}
#if defined(__ARM__) || defined(__ARM64__)
void CompiledVectorExpression::generateJitCode() {
CodeHolder code;
code.init(runtime.environment());
a64::Compiler c(&code);
c.addFunc(FuncSignatureT<void>());
vector<arm::Vec> workspaceVar(workspace.size()/width);
for (int i = 0; i < (int) workspaceVar.size(); i++)
workspaceVar[i] = c.newVecQ();
arm::Gp argsPointer = c.newIntPtr();
c.mov(argsPointer, imm(&argValues[0]));
vector<vector<int> > groups, groupPowers;
vector<int> stepGroup;
findPowerGroups(groups, groupPowers, stepGroup);
// Load the arguments into variables.
arm::Gp variablePointer = c.newIntPtr();
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
map<string, int>::iterator index = variableIndices.find(*iter);
c.mov(variablePointer, imm(getVariablePointer(index->first)));
c.ldr(workspaceVar[index->second].s4(), arm::ptr(variablePointer, 0));
}
// Make a list of all constants that will be needed for evaluation.
vector<int> operationConstantIndex(operation.size(), -1);
for (int step = 0; step < (int) operation.size(); step++) {
// Find the constant value (if any) used by this operation.
Operation& op = *operation[step];
float value;
if (op.getId() == Operation::CONSTANT)
value = dynamic_cast<Operation::Constant&> (op).getValue();
else if (op.getId() == Operation::ADD_CONSTANT)
value = dynamic_cast<Operation::AddConstant&> (op).getValue();
else if (op.getId() == Operation::MULTIPLY_CONSTANT)
value = dynamic_cast<Operation::MultiplyConstant&> (op).getValue();
else if (op.getId() == Operation::RECIPROCAL)
value = 1.0;
else if (op.getId() == Operation::STEP)
value = 1.0;
else if (op.getId() == Operation::DELTA)
value = 1.0;
else if (op.getId() == Operation::POWER_CONSTANT) {
if (stepGroup[step] == -1)
value = dynamic_cast<Operation::PowerConstant&> (op).getValue();
else
value = 1.0;
} else
continue;
// See if we already have a variable for this constant.
for (int i = 0; i < (int) constants.size(); i++)
if (value == constants[i]) {
operationConstantIndex[step] = i;
break;
}
if (operationConstantIndex[step] == -1) {
operationConstantIndex[step] = constants.size();
constants.push_back(value);
}
}
// Load constants into variables.
vector<arm::Vec> constantVar(constants.size());
if (constants.size() > 0) {
arm::Gp constantsPointer = c.newIntPtr();
for (int i = 0; i < (int) constants.size(); i++) {
c.mov(constantsPointer, imm(&constants[i]));
constantVar[i] = c.newVecQ();
c.ld1r(constantVar[i].s4(), arm::ptr(constantsPointer));
}
}
// Evaluate the operations.
vector<bool> hasComputedPower(operation.size(), false);
arm::Vec argReg = c.newVecS();
arm::Vec doubleArgReg = c.newVecD();
arm::Vec doubleResultReg = c.newVecD();
for (int step = 0; step < (int) operation.size(); step++) {
if (hasComputedPower[step])
continue;
// When one or more steps involve raising the same argument to multiple integer
// powers, we can compute them all together for efficiency.
if (stepGroup[step] != -1) {
vector<int>& group = groups[stepGroup[step]];
vector<int>& powers = groupPowers[stepGroup[step]];
arm::Vec multiplier = c.newVecQ();
if (powers[0] > 0)
c.mov(multiplier.s4(), workspaceVar[arguments[step][0]].s4());
else {
c.fdiv(multiplier.s4(), constantVar[operationConstantIndex[step]].s4(), workspaceVar[arguments[step][0]].s4());
for (int i = 0; i < powers.size(); i++)
powers[i] = -powers[i];
}
vector<bool> hasAssigned(group.size(), false);
bool done = false;
while (!done) {
done = true;
for (int i = 0; i < group.size(); i++) {
if (powers[i] % 2 == 1) {
if (!hasAssigned[i])
c.mov(workspaceVar[target[group[i]]].s4(), multiplier.s4());
else
c.fmul(workspaceVar[target[group[i]]].s4(), workspaceVar[target[group[i]]].s4(), multiplier.s4());
hasAssigned[i] = true;
}
powers[i] >>= 1;
if (powers[i] != 0)
done = false;
}
if (!done)
c.fmul(multiplier.s4(), multiplier.s4(), multiplier.s4());
}
for (int step : group)
hasComputedPower[step] = true;
continue;
}
// Evaluate the step.
Operation& op = *operation[step];
vector<int> args = arguments[step];
if (args.size() == 1) {
// One or more sequential arguments. Fill out the list.
for (int i = 1; i < op.getNumArguments(); i++)
args.push_back(args[0] + i);
}
// Generate instructions to execute this operation.
switch (op.getId()) {
case Operation::CONSTANT:
c.mov(workspaceVar[target[step]].s4(), constantVar[operationConstantIndex[step]].s4());
break;
case Operation::ADD:
c.fadd(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
break;
case Operation::SUBTRACT:
c.fsub(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
break;
case Operation::MULTIPLY:
c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
break;
case Operation::DIVIDE:
c.fdiv(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
break;
case Operation::POWER:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], powf);
break;
case Operation::NEGATE:
c.fneg(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::SQRT:
c.fsqrt(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::EXP:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], expf);
break;
case Operation::LOG:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], logf);
break;
case Operation::SIN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinf);
break;
case Operation::COS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosf);
break;
case Operation::TAN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanf);
break;
case Operation::ASIN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asinf);
break;
case Operation::ACOS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acosf);
break;
case Operation::ATAN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atanf);
break;
case Operation::ATAN2:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2f);
break;
case Operation::SINH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinhf);
break;
case Operation::COSH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], coshf);
break;
case Operation::TANH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanhf);
break;
case Operation::STEP:
c.cmge(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0));
c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::DELTA:
c.cmeq(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0));
c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::SQUARE:
c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::CUBE:
c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[0]].s4());
c.fmul(workspaceVar[target[step]].s4(), workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::RECIPROCAL:
c.fdiv(workspaceVar[target[step]].s4(), constantVar[operationConstantIndex[step]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::ADD_CONSTANT:
c.fadd(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), constantVar[operationConstantIndex[step]].s4());
break;
case Operation::MULTIPLY_CONSTANT:
c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), constantVar[operationConstantIndex[step]].s4());
break;
case Operation::POWER_CONSTANT:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], powf);
break;
case Operation::MIN:
c.fmin(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
break;
case Operation::MAX:
c.fmax(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
break;
case Operation::ABS:
c.fabs(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::FLOOR:
c.frintm(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::CEIL:
c.frintp(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::SELECT:
c.fcmeq(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0));
c.bsl(workspaceVar[target[step]], workspaceVar[args[2]], workspaceVar[args[1]]);
break;
default:
// Just invoke evaluateOperation().
for (int element = 0; element < width; element++) {
for (int i = 0; i < (int) args.size(); i++) {
c.ins(argReg.s(0), workspaceVar[args[i]].s(element));
c.fcvt(doubleArgReg, argReg);
c.str(doubleArgReg, arm::ptr(argsPointer, 8*i));
}
arm::Gp fn = c.newIntPtr();
c.mov(fn, imm((void*) evaluateOperation));
InvokeNode* invoke;
c.invoke(&invoke, fn, FuncSignatureT<double, Operation*, double*>());
invoke->setArg(0, imm(&op));
invoke->setArg(1, imm(&argValues[0]));
invoke->setRet(0, doubleResultReg);
c.fcvt(argReg, doubleResultReg);
c.ins(workspaceVar[target[step]].s(element), argReg.s(0));
}
}
}
arm::Gp resultPointer = c.newIntPtr();
c.mov(resultPointer, imm(&workspace[workspace.size()-width]));
c.str(workspaceVar.back().s4(), arm::ptr(resultPointer, 0));
c.endFunc();
c.finalize();
runtime.add(&jitCode, &code);
}
void CompiledVectorExpression::generateSingleArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg, float (*function)(float)) {
arm::Gp fn = c.newIntPtr();
c.mov(fn, imm((void*) function));
arm::Vec a = c.newVecS();
arm::Vec d = c.newVecS();
for (int element = 0; element < width; element++) {
c.ins(a.s(0), arg.s(element));
InvokeNode* invoke;
c.invoke(&invoke, fn, FuncSignatureT<float, float>());
invoke->setArg(0, a);
invoke->setRet(0, d);
c.ins(dest.s(element), d.s(0));
}
}
void CompiledVectorExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg1, arm::Vec& arg2, float (*function)(float, float)) {
arm::Gp fn = c.newIntPtr();
c.mov(fn, imm((void*) function));
arm::Vec a1 = c.newVecS();
arm::Vec a2 = c.newVecS();
arm::Vec d = c.newVecS();
for (int element = 0; element < width; element++) {
c.ins(a1.s(0), arg1.s(element));
c.ins(a2.s(0), arg2.s(element));
InvokeNode* invoke;
c.invoke(&invoke, fn, FuncSignatureT<float, float, float>());
invoke->setArg(0, a1);
invoke->setArg(1, a2);
invoke->setRet(0, d);
c.ins(dest.s(element), d.s(0));
}
}
#else
void CompiledVectorExpression::generateJitCode() {
const CpuInfo& cpu = CpuInfo::host();
if (!cpu.hasFeature(CpuFeatures::X86::kAVX))
return;
CodeHolder code;
code.init(runtime.environment());
x86::Compiler c(&code);
FuncNode* funcNode = c.addFunc(FuncSignatureT<double>());
funcNode->frame().setAvxEnabled();
vector<x86::Ymm> workspaceVar(workspace.size()/width);
for (int i = 0; i < (int) workspaceVar.size(); i++)
workspaceVar[i] = c.newYmmPs();
x86::Gp argsPointer = c.newIntPtr();
c.mov(argsPointer, imm(&argValues[0]));
vector<vector<int> > groups, groupPowers;
vector<int> stepGroup;
findPowerGroups(groups, groupPowers, stepGroup);
// Load the arguments into variables.
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
map<string, int>::iterator index = variableIndices.find(*iter);
x86::Gp variablePointer = c.newIntPtr();
c.mov(variablePointer, imm(getVariablePointer(index->first)));
if (width == 4)
c.vmovdqu(workspaceVar[index->second].xmm(), x86::ptr(variablePointer, 0, 0));
else
c.vmovdqu(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0));
}
// Make a list of all constants that will be needed for evaluation.
vector<int> operationConstantIndex(operation.size(), -1);
for (int step = 0; step < (int) operation.size(); step++) {
// Find the constant value (if any) used by this operation.
Operation& op = *operation[step];
double value;
if (op.getId() == Operation::CONSTANT)
value = dynamic_cast<Operation::Constant&> (op).getValue();
else if (op.getId() == Operation::ADD_CONSTANT)
value = dynamic_cast<Operation::AddConstant&> (op).getValue();
else if (op.getId() == Operation::MULTIPLY_CONSTANT)
value = dynamic_cast<Operation::MultiplyConstant&> (op).getValue();
else if (op.getId() == Operation::RECIPROCAL)
value = 1.0;
else if (op.getId() == Operation::STEP)
value = 1.0;
else if (op.getId() == Operation::DELTA)
value = 1.0;
else if (op.getId() == Operation::ABS) {
int mask = 0x7FFFFFFF;
value = *reinterpret_cast<float*>(&mask);
}
else if (op.getId() == Operation::POWER_CONSTANT) {
if (stepGroup[step] == -1)
value = dynamic_cast<Operation::PowerConstant&> (op).getValue();
else
value = 1.0;
} else
continue;
// See if we already have a variable for this constant.
for (int i = 0; i < (int) constants.size(); i++)
if (value == constants[i]) {
operationConstantIndex[step] = i;
break;
}
if (operationConstantIndex[step] == -1) {
operationConstantIndex[step] = constants.size();
constants.push_back(value);
}
}
// Load constants into variables.
vector<x86::Ymm> constantVar(constants.size());
if (constants.size() > 0) {
x86::Gp constantsPointer = c.newIntPtr();
c.mov(constantsPointer, imm(&constants[0]));
for (int i = 0; i < (int) constants.size(); i++) {
constantVar[i] = c.newYmmPs();
c.vbroadcastss(constantVar[i], x86::ptr(constantsPointer, 4*i, 0));
}
}
// Evaluate the operations.
vector<bool> hasComputedPower(operation.size(), false);
x86::Ymm argReg = c.newYmm();
x86::Ymm doubleArgReg = c.newYmm();
x86::Ymm doubleResultReg = c.newYmm();
for (int step = 0; step < (int) operation.size(); step++) {
if (hasComputedPower[step])
continue;
// When one or more steps involve raising the same argument to multiple integer
// powers, we can compute them all together for efficiency.
if (stepGroup[step] != -1) {
vector<int>& group = groups[stepGroup[step]];
vector<int>& powers = groupPowers[stepGroup[step]];
x86::Ymm multiplier = c.newYmmPs();
if (powers[0] > 0)
c.vmovdqu(multiplier, workspaceVar[arguments[step][0]]);
else {
c.vdivps(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]);
for (int i = 0; i < powers.size(); i++)
powers[i] = -powers[i];
}
vector<bool> hasAssigned(group.size(), false);
bool done = false;
while (!done) {
done = true;
for (int i = 0; i < group.size(); i++) {
if (powers[i] % 2 == 1) {
if (!hasAssigned[i])
c.vmovdqu(workspaceVar[target[group[i]]], multiplier);
else
c.vmulps(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier);
hasAssigned[i] = true;
}
powers[i] >>= 1;
if (powers[i] != 0)
done = false;
}
if (!done)
c.vmulps(multiplier, multiplier, multiplier);
}
for (int step : group)
hasComputedPower[step] = true;
continue;
}
// Evaluate the step.
Operation& op = *operation[step];
vector<int> args = arguments[step];
if (args.size() == 1) {
// One or more sequential arguments. Fill out the list.
for (int i = 1; i < op.getNumArguments(); i++)
args.push_back(args[0] + i);
}
// Generate instructions to execute this operation.
switch (op.getId()) {
case Operation::CONSTANT:
c.vmovdqu(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::ADD:
c.vaddps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::SUBTRACT:
c.vsubps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::MULTIPLY:
c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::DIVIDE:
c.vdivps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::POWER:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], powf);
break;
case Operation::NEGATE:
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
c.vsubps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::SQRT:
c.vsqrtps(workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::EXP:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], expf);
break;
case Operation::LOG:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], logf);
break;
case Operation::SIN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinf);
break;
case Operation::COS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosf);
break;
case Operation::TAN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanf);
break;
case Operation::ASIN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asinf);
break;
case Operation::ACOS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acosf);
break;
case Operation::ATAN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atanf);
break;
case Operation::ATAN2:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2f);
break;
case Operation::SINH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinhf);
break;
case Operation::COSH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], coshf);
break;
case Operation::TANH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanhf);
break;
case Operation::STEP:
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
c.vcmpps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18
c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::DELTA:
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
c.vcmpps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OQ = 0
c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::SQUARE:
c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
break;
case Operation::CUBE:
c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
c.vmulps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::RECIPROCAL:
c.vdivps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]);
break;
case Operation::ADD_CONSTANT:
c.vaddps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break;
case Operation::MULTIPLY_CONSTANT:
c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break;
case Operation::POWER_CONSTANT:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], powf);
break;
case Operation::MIN:
c.vminps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::MAX:
c.vmaxps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::ABS:
c.vandps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break;
case Operation::FLOOR:
c.vroundps(workspaceVar[target[step]], workspaceVar[args[0]], imm(1));
break;
case Operation::CEIL:
c.vroundps(workspaceVar[target[step]], workspaceVar[args[0]], imm(2));
break;
case Operation::SELECT:
{
x86::Ymm mask = c.newYmmPs();
c.vxorps(mask, mask, mask);
c.vcmpps(mask, mask, workspaceVar[args[0]], imm(0)); // Comparison mode is _CMP_EQ_OQ = 0
c.vblendvps(workspaceVar[target[step]], workspaceVar[args[1]], workspaceVar[args[2]], mask);
break;
}
default:
// Just invoke evaluateOperation().
for (int element = 0; element < width; element++) {
for (int i = 0; i < (int) args.size(); i++) {
if (element < 4)
c.vshufps(argReg, workspaceVar[args[i]], workspaceVar[args[i]], imm(element));
else {
c.vperm2f128(argReg, workspaceVar[args[i]], workspaceVar[args[i]], imm(1));
c.vshufps(argReg, argReg, argReg, imm(element-4));
}
c.vcvtss2sd(doubleArgReg.xmm(), doubleArgReg.xmm(), argReg.xmm());
c.vmovsd(x86::ptr(argsPointer, 8*i, 0), doubleArgReg.xmm());
}
x86::Gp fn = c.newIntPtr();
c.mov(fn, imm((void*) evaluateOperation));
InvokeNode* invoke;
c.invoke(&invoke, fn, FuncSignatureT<double, Operation*, double*>());
invoke->setArg(0, imm(&op));
invoke->setArg(1, imm(&argValues[0]));
invoke->setRet(0, doubleResultReg);
c.vcvtsd2ss(argReg.xmm(), argReg.xmm(), doubleResultReg.xmm());
if (element > 3)
c.vperm2f128(argReg, argReg, argReg, imm(0));
if (element != 0)
c.vshufps(argReg, argReg, argReg, imm(0));
c.vblendps(workspaceVar[target[step]], workspaceVar[target[step]], argReg, 1<<element);
}
}
}
x86::Gp resultPointer = c.newIntPtr();
c.mov(resultPointer, imm(&workspace[workspace.size()-width]));
if (width == 4)
c.vmovdqu(x86::ptr(resultPointer, 0, 0), workspaceVar.back().xmm());
else
c.vmovdqu(x86::ptr(resultPointer, 0, 0), workspaceVar.back());
c.endFunc();
c.finalize();
runtime.add(&jitCode, &code);
}
void CompiledVectorExpression::generateSingleArgCall(x86::Compiler& c, x86::Ymm& dest, x86::Ymm& arg, float (*function)(float)) {
x86::Gp fn = c.newIntPtr();
c.mov(fn, imm((void*) function));
x86::Ymm a = c.newYmm();
x86::Ymm d = c.newYmm();
for (int element = 0; element < width; element++) {
if (element < 4)
c.vshufps(a, arg, arg, imm(element));
else {
c.vperm2f128(a, arg, arg, imm(1));
c.vshufps(a, a, a, imm(element-4));
}
InvokeNode* invoke;
c.invoke(&invoke, fn, FuncSignatureT<float, float>());
invoke->setArg(0, a);
invoke->setRet(0, d);
if (element > 3)
c.vperm2f128(d, d, d, imm(0));
if (element != 0)
c.vshufps(d, d, d, imm(0));
c.vblendps(dest, dest, d, 1<<element);
}
}
void CompiledVectorExpression::generateTwoArgCall(x86::Compiler& c, x86::Ymm& dest, x86::Ymm& arg1, x86::Ymm& arg2, float (*function)(float, float)) {
x86::Gp fn = c.newIntPtr();
c.mov(fn, imm((void*) function));
x86::Ymm a1 = c.newYmm();
x86::Ymm a2 = c.newYmm();
x86::Ymm d = c.newYmm();
for (int element = 0; element < width; element++) {
if (element < 4) {
c.vshufps(a1, arg1, arg1, imm(element));
c.vshufps(a2, arg2, arg2, imm(element));
}
else {
c.vperm2f128(a1, arg1, arg1, imm(1));
c.vperm2f128(a2, arg2, arg2, imm(1));
c.vshufps(a1, a1, a1, imm(element-4));
c.vshufps(a2, a2, a2, imm(element-4));
}
InvokeNode* invoke;
c.invoke(&invoke, fn, FuncSignatureT<float, float, float>());
invoke->setArg(0, a1);
invoke->setArg(1, a2);
invoke->setRet(0, d);
if (element > 3)
c.vperm2f128(d, d, d, imm(0));
if (element != 0)
c.vshufps(d, d, d, imm(0));
c.vblendps(dest, dest, d, 1<<element);
}
}
#endif
#endif
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2021 Stanford University and the Authors. * * Portions copyright (c) 2009-2022 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "lepton/ParsedExpression.h" #include "lepton/ParsedExpression.h"
#include "lepton/CompiledExpression.h" #include "lepton/CompiledExpression.h"
#include "lepton/CompiledVectorExpression.h"
#include "lepton/ExpressionProgram.h" #include "lepton/ExpressionProgram.h"
#include "lepton/Operation.h" #include "lepton/Operation.h"
#include <limits> #include <limits>
...@@ -373,6 +374,10 @@ CompiledExpression ParsedExpression::createCompiledExpression() const { ...@@ -373,6 +374,10 @@ CompiledExpression ParsedExpression::createCompiledExpression() const {
return CompiledExpression(*this); return CompiledExpression(*this);
} }
CompiledVectorExpression ParsedExpression::createCompiledVectorExpression(int width) const {
return CompiledVectorExpression(*this, width);
}
ParsedExpression ParsedExpression::renameVariables(const map<string, string>& replacements) const { ParsedExpression ParsedExpression::renameVariables(const map<string, string>& replacements) const {
return ParsedExpression(renameNodeVariables(getRootNode(), replacements)); return ParsedExpression(renameNodeVariables(getRootNode(), replacements));
} }
......
#include "../libraries/lepton/include/Lepton.h" #include "../libraries/lepton/include/Lepton.h"
#include "openmm/internal/AssertionUtilities.h" #include "openmm/internal/AssertionUtilities.h"
#include "lepton/CompiledVectorExpression.h"
#include <iostream> #include <iostream>
#include <limits> #include <limits>
...@@ -101,7 +102,7 @@ void verifyEvaluation(const string& expression, double x, double y, double expec ...@@ -101,7 +102,7 @@ void verifyEvaluation(const string& expression, double x, double y, double expec
compiled.getVariableReference("y") = y; compiled.getVariableReference("y") = y;
value = compiled.evaluate(); value = compiled.evaluate();
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10); ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
// Try specifying memory locations for the compiled expression. // Try specifying memory locations for the compiled expression.
map<string, double*> variablePointers; map<string, double*> variablePointers;
...@@ -114,6 +115,41 @@ void verifyEvaluation(const string& expression, double x, double y, double expec ...@@ -114,6 +115,41 @@ void verifyEvaluation(const string& expression, double x, double y, double expec
ASSERT_EQUAL(&x, &compiled2.getVariableReference("x")); ASSERT_EQUAL(&x, &compiled2.getVariableReference("x"));
ASSERT_EQUAL(&y, &compiled2.getVariableReference("y")); ASSERT_EQUAL(&y, &compiled2.getVariableReference("y"));
// Try evaluating it as a vector.
for (int width : CompiledVectorExpression::getAllowedWidths()) {
CompiledVectorExpression vector = parsed.createCompiledVectorExpression(width);
for (int i = 0; i < width; i++) {
if (vector.getVariables().find("x") != vector.getVariables().end())
for (int j = 0; j < width; j++)
vector.getVariablePointer("x")[j] = (i == j ? x : -100.0);
if (vector.getVariables().find("y") != vector.getVariables().end())
for (int j = 0; j < width; j++)
vector.getVariablePointer("y")[j] = (i == j ? y : -100.0);
const float* result = vector.evaluate();
ASSERT_EQUAL_TOL(expectedValue, result[i], 1e-6);
}
}
// Specify memory locations for the vector expression.
float xvec[8], yvec[8];
map<string, float*> vecVariablePointers;
vecVariablePointers["x"] = xvec;
vecVariablePointers["y"] = yvec;
for (int width : CompiledVectorExpression::getAllowedWidths()) {
CompiledVectorExpression vector2 = parsed.createCompiledVectorExpression(width);
vector2.setVariableLocations(vecVariablePointers);
for (int i = 0; i < width; i++) {
for (int j = 0; j < width; j++) {
xvec[j] = (i == j ? x : -100.0);
yvec[j] = (i == j ? y : -100.0);
}
const float* result = vector2.evaluate();
ASSERT_EQUAL_TOL(expectedValue, result[i], 1e-6);
}
}
// Make sure that variable renaming works. // Make sure that variable renaming works.
variables.clear(); variables.clear();
...@@ -143,12 +179,12 @@ void verifyInvalidExpression(const string& expression) { ...@@ -143,12 +179,12 @@ void verifyInvalidExpression(const string& expression) {
* Verify that two numbers have the same value. * Verify that two numbers have the same value.
*/ */
void assertNumbersEqual(double val1, double val2) { void assertNumbersEqual(double val1, double val2, double tol=1e-10) {
const double inf = numeric_limits<double>::infinity(); const double inf = numeric_limits<double>::infinity();
if (val1 == val1 || val2 == val2) // If both are NaN, that's fine. if (val1 == val1 || val2 == val2) // If both are NaN, that's fine.
if (val1 != inf || val2 != inf) // Both infinity is also fine. if (val1 != inf || val2 != inf) // Both infinity is also fine.
if (val1 != -inf || val2 != -inf) // Same for -infinity. if (val1 != -inf || val2 != -inf) // Same for -infinity.
ASSERT_EQUAL_TOL(val1, val2, 1e-10); ASSERT_EQUAL_TOL(val1, val2, tol);
} }
/** /**
...@@ -177,6 +213,31 @@ void verifySameValue(const ParsedExpression& exp1, const ParsedExpression& exp2, ...@@ -177,6 +213,31 @@ void verifySameValue(const ParsedExpression& exp1, const ParsedExpression& exp2,
compiled2.getVariableReference("y") = y; compiled2.getVariableReference("y") = y;
assertNumbersEqual(val1, compiled1.evaluate()); assertNumbersEqual(val1, compiled1.evaluate());
assertNumbersEqual(val2, compiled2.evaluate()); assertNumbersEqual(val2, compiled2.evaluate());
// Now check CompiledVectorizedExpressions.
for (int width : CompiledVectorExpression::getAllowedWidths()) {
CompiledVectorExpression vector1 = exp1.createCompiledVectorExpression(width);
CompiledVectorExpression vector2 = exp2.createCompiledVectorExpression(width);
for (int i = 0; i < width; i++) {
if (vector1.getVariables().find("x") != vector1.getVariables().end())
for (int j = 0; j < width; j++)
vector1.getVariablePointer("x")[j] = (i == j ? x : -100.0);
if (vector1.getVariables().find("y") != vector1.getVariables().end())
for (int j = 0; j < width; j++)
vector1.getVariablePointer("y")[j] = (i == j ? y : -100.0);
if (vector2.getVariables().find("x") != vector2.getVariables().end())
for (int j = 0; j < width; j++)
vector2.getVariablePointer("x")[j] = (i == j ? x : -100.0);
if (vector2.getVariables().find("y") != vector2.getVariables().end())
for (int j = 0; j < width; j++)
vector2.getVariablePointer("y")[j] = (i == j ? y : -100.0);
const float* result1 = vector1.evaluate();
const float* result2 = vector2.evaluate();
assertNumbersEqual(val1, result1[i], 1e-6);
assertNumbersEqual(val2, result2[i], 1e-6);
}
}
} }
/** /**
...@@ -235,6 +296,7 @@ int main() { ...@@ -235,6 +296,7 @@ int main() {
verifyEvaluation("2.1e-4*x*(y+1)", 3.0, 1.0, 1.26e-3); verifyEvaluation("2.1e-4*x*(y+1)", 3.0, 1.0, 1.26e-3);
verifyEvaluation("sin(2.5)", std::sin(2.5)); verifyEvaluation("sin(2.5)", std::sin(2.5));
verifyEvaluation("cot(x)", 3.0, 1.0, 1.0/std::tan(3.0)); verifyEvaluation("cot(x)", 3.0, 1.0, 1.0/std::tan(3.0));
verifyEvaluation("log(x)", 3.0, 1.0, std::log(3.0));
verifyEvaluation("x^2+y^3+x^-1+y^(1/2)", 1.0, 1.0, 4.0); verifyEvaluation("x^2+y^3+x^-1+y^(1/2)", 1.0, 1.0, 4.0);
verifyEvaluation("(2*x)*3", 4.0, 4.0, 24.0); verifyEvaluation("(2*x)*3", 4.0, 4.0, 24.0);
verifyEvaluation("(x*2)*3", 4.0, 4.0, 24.0); verifyEvaluation("(x*2)*3", 4.0, 4.0, 24.0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment