Unverified Commit aafb8b5b authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Vectorize custom expressions on CPU (#3552)

* Began implementing vectorization of Lepton expressions

* Tests for vector expressions

* Implemented CompiledVectorExpression for x86

* Bug fix

* Optimized select() on ARM

* Optimized select() on x86

* CompiledVectorExpression supports AVX

* Bug fix

* Updated docs

* Use VEX encoded instructions for CompiledExpression

* Optimized min() and max() on x86

* Optimized min() and max() on ARM

* Fixed compilation error

* Upgrade AsmJit
parent 6fb1c8a4
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009=2021 Stanford University and the Authors. *
* Portions copyright (c) 2009-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -41,6 +41,7 @@ namespace Lepton {
class CompiledExpression;
class ExpressionProgram;
class CompiledVectorExpression;
/**
* This class represents the result of parsing an expression. It provides methods for working with the
......@@ -102,6 +103,16 @@ public:
* Create a CompiledExpression that represents the same calculation as this expression.
*/
CompiledExpression createCompiledExpression() const;
/**
* Create a CompiledVectorExpression that allows the expression to be evaluated efficiently
* using the CPU's vector unit.
*
* @param width the width of the vectors to evaluate it on. The allowed values
* depend on the CPU. 4 is always allowed, and 8 is allowed on
* x86 processors with AVX. Call CompiledVectorExpression::getAllowedWidths()
* to query the allowed widths on the current processor.
*/
CompiledVectorExpression createCompiledVectorExpression(int width) const;
/**
* Create a new ParsedExpression which is identical to this one, except that the names of some
* variables have been changed.
......
......@@ -151,7 +151,7 @@ void CompiledExpression::setVariableLocations(map<string, double*>& variableLoca
if (workspace.size() > 0)
generateJitCode();
#else
#endif
// Make a list of all variables we will need to copy before evaluating the expression.
variablesToCopy.clear();
......@@ -160,13 +160,11 @@ void CompiledExpression::setVariableLocations(map<string, double*>& variableLoca
if (pointer != variablePointers.end())
variablesToCopy.push_back(make_pair(&workspace[iter->second], pointer->second));
}
#endif
}
double CompiledExpression::evaluate() const {
#ifdef LEPTON_USE_JIT
return jitCode();
#else
if (jitCode)
return jitCode();
for (int i = 0; i < variablesToCopy.size(); i++)
*variablesToCopy[i].first = *variablesToCopy[i].second;
......@@ -183,7 +181,6 @@ double CompiledExpression::evaluate() const {
}
}
return workspace[workspace.size()-1];
#endif
}
#ifdef LEPTON_USE_JIT
......@@ -458,14 +455,24 @@ void CompiledExpression::generateJitCode() {
case Operation::POWER_CONSTANT:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow);
break;
case Operation::MIN:
c.fmin(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::MAX:
c.fmax(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::ABS:
c.fabs(workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::FLOOR:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], floor);
c.frintm(workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::CEIL:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], ceil);
c.frintp(workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::SELECT:
c.fcmeq(workspaceVar[target[step]], workspaceVar[args[0]], imm(0));
c.bsl(workspaceVar[target[step]], workspaceVar[args[2]], workspaceVar[args[1]]);
break;
default:
// Just invoke evaluateOperation().
......@@ -507,10 +514,14 @@ void CompiledExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, ar
}
#else
void CompiledExpression::generateJitCode() {
const CpuInfo& cpu = CpuInfo::host();
if (!cpu.hasFeature(CpuFeatures::X86::kAVX))
return;
CodeHolder code;
code.init(runtime.environment());
x86::Compiler c(&code);
c.addFunc(FuncSignatureT<double>());
FuncNode* funcNode = c.addFunc(FuncSignatureT<double>());
funcNode->frame().setAvxEnabled();
vector<x86::Xmm> workspaceVar(workspace.size());
for (int i = 0; i < (int) workspaceVar.size(); i++)
workspaceVar[i] = c.newXmmSd();
......@@ -522,11 +533,11 @@ void CompiledExpression::generateJitCode() {
// Load the arguments into variables.
x86::Gp variablePointer = c.newIntPtr();
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
map<string, int>::iterator index = variableIndices.find(*iter);
x86::Gp variablePointer = c.newIntPtr();
c.mov(variablePointer, imm(&getVariableReference(index->first)));
c.movsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0));
c.vmovsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0));
}
// Make a list of all constants that will be needed for evaluation.
......@@ -549,6 +560,10 @@ void CompiledExpression::generateJitCode() {
value = 1.0;
else if (op.getId() == Operation::DELTA)
value = 1.0;
else if (op.getId() == Operation::ABS) {
long long mask = 0x7FFFFFFFFFFFFFFF;
value = *reinterpret_cast<double*>(&mask);
}
else if (op.getId() == Operation::POWER_CONSTANT) {
if (stepGroup[step] == -1)
value = dynamic_cast<Operation::PowerConstant&>(op).getValue();
......@@ -579,7 +594,7 @@ void CompiledExpression::generateJitCode() {
c.mov(constantsPointer, imm(&constants[0]));
for (int i = 0; i < (int) constants.size(); i++) {
constantVar[i] = c.newXmmSd();
c.movsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0));
c.vmovsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0));
}
}
......@@ -598,10 +613,9 @@ void CompiledExpression::generateJitCode() {
vector<int>& powers = groupPowers[stepGroup[step]];
x86::Xmm multiplier = c.newXmmSd();
if (powers[0] > 0)
c.movsd(multiplier, workspaceVar[arguments[step][0]]);
c.vmovsd(multiplier, workspaceVar[arguments[step][0]], workspaceVar[arguments[step][0]]);
else {
c.movsd(multiplier, constantVar[operationConstantIndex[step]]);
c.divsd(multiplier, workspaceVar[arguments[step][0]]);
c.vdivsd(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]);
for (int i = 0; i < powers.size(); i++)
powers[i] = -powers[i];
}
......@@ -612,9 +626,9 @@ void CompiledExpression::generateJitCode() {
for (int i = 0; i < group.size(); i++) {
if (powers[i]%2 == 1) {
if (!hasAssigned[i])
c.movsd(workspaceVar[target[group[i]]], multiplier);
c.vmovsd(workspaceVar[target[group[i]]], multiplier, multiplier);
else
c.mulsd(workspaceVar[target[group[i]]], multiplier);
c.vmulsd(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier);
hasAssigned[i] = true;
}
powers[i] >>= 1;
......@@ -622,7 +636,7 @@ void CompiledExpression::generateJitCode() {
done = false;
}
if (!done)
c.mulsd(multiplier, multiplier);
c.vmulsd(multiplier, multiplier, multiplier);
}
for (int step : group)
hasComputedPower[step] = true;
......@@ -644,33 +658,29 @@ void CompiledExpression::generateJitCode() {
switch (op.getId()) {
case Operation::CONSTANT:
c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
c.vmovsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::ADD:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.addsd(workspaceVar[target[step]], workspaceVar[args[1]]);
c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::SUBTRACT:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.subsd(workspaceVar[target[step]], workspaceVar[args[1]]);
c.vsubsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::MULTIPLY:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[1]]);
c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::DIVIDE:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.divsd(workspaceVar[target[step]], workspaceVar[args[1]]);
c.vdivsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::POWER:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow);
break;
case Operation::NEGATE:
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
c.subsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
c.vsubsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::SQRT:
c.sqrtsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.vsqrtsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
break;
case Operation::EXP:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp);
......@@ -709,53 +719,62 @@ void CompiledExpression::generateJitCode() {
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh);
break;
case Operation::STEP:
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18
c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18
c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::DELTA:
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16
c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16
c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::SQUARE:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
break;
case Operation::CUBE:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
c.vmulsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::RECIPROCAL:
c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
c.divsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.vdivsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]);
break;
case Operation::ADD_CONSTANT:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.addsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break;
case Operation::MULTIPLY_CONSTANT:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break;
case Operation::POWER_CONSTANT:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow);
break;
case Operation::MIN:
c.vminsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::MAX:
c.vmaxsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::ABS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs);
c.vandpd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break;
case Operation::FLOOR:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], floor);
c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(1));
break;
case Operation::CEIL:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], ceil);
c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(2));
break;
case Operation::SELECT:
{
x86::Xmm mask = c.newXmmSd();
c.vxorps(mask, mask, mask);
c.vcmpsd(mask, mask, workspaceVar[args[0]], imm(0)); // Comparison mode is _CMP_EQ_OQ = 0
c.vblendvps(workspaceVar[target[step]], workspaceVar[args[1]], workspaceVar[args[2]], mask);
break;
}
default:
// Just invoke evaluateOperation().
for (int i = 0; i < (int) args.size(); i++)
c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]);
c.vmovsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]);
x86::Gp fn = c.newIntPtr();
c.mov(fn, imm((void*) evaluateOperation));
InvokeNode* invoke;
......
This diff is collapsed.
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2021 Stanford University and the Authors. *
* Portions copyright (c) 2009-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -31,6 +31,7 @@
#include "lepton/ParsedExpression.h"
#include "lepton/CompiledExpression.h"
#include "lepton/CompiledVectorExpression.h"
#include "lepton/ExpressionProgram.h"
#include "lepton/Operation.h"
#include <limits>
......@@ -373,6 +374,10 @@ CompiledExpression ParsedExpression::createCompiledExpression() const {
return CompiledExpression(*this);
}
// Create a CompiledVectorExpression for this expression by passing the expression
// itself and the requested vector width straight to the CompiledVectorExpression
// constructor. No checking of width is done in this method.
// NOTE(review): presumably an unsupported width is rejected by the constructor -- confirm.
CompiledVectorExpression ParsedExpression::createCompiledVectorExpression(int width) const {
return CompiledVectorExpression(*this, width);
}
// Build a new ParsedExpression from the tree that renameNodeVariables() produces
// when it is applied to this expression's root node with the given replacements
// map. This object is not modified (the method is const).
ParsedExpression ParsedExpression::renameVariables(const map<string, string>& replacements) const {
    const auto& renamedRoot = renameNodeVariables(getRootNode(), replacements);
    return ParsedExpression(renamedRoot);
}
......
#include "../libraries/lepton/include/Lepton.h"
#include "openmm/internal/AssertionUtilities.h"
#include "lepton/CompiledVectorExpression.h"
#include <iostream>
#include <limits>
......@@ -101,7 +102,7 @@ void verifyEvaluation(const string& expression, double x, double y, double expec
compiled.getVariableReference("y") = y;
value = compiled.evaluate();
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
// Try specifying memory locations for the compiled expression.
map<string, double*> variablePointers;
......@@ -114,6 +115,41 @@ void verifyEvaluation(const string& expression, double x, double y, double expec
ASSERT_EQUAL(&x, &compiled2.getVariableReference("x"));
ASSERT_EQUAL(&y, &compiled2.getVariableReference("y"));
// Try evaluating it as a vector.
for (int width : CompiledVectorExpression::getAllowedWidths()) {
CompiledVectorExpression vector = parsed.createCompiledVectorExpression(width);
for (int i = 0; i < width; i++) {
if (vector.getVariables().find("x") != vector.getVariables().end())
for (int j = 0; j < width; j++)
vector.getVariablePointer("x")[j] = (i == j ? x : -100.0);
if (vector.getVariables().find("y") != vector.getVariables().end())
for (int j = 0; j < width; j++)
vector.getVariablePointer("y")[j] = (i == j ? y : -100.0);
const float* result = vector.evaluate();
ASSERT_EQUAL_TOL(expectedValue, result[i], 1e-6);
}
}
// Specify memory locations for the vector expression.
float xvec[8], yvec[8];
map<string, float*> vecVariablePointers;
vecVariablePointers["x"] = xvec;
vecVariablePointers["y"] = yvec;
for (int width : CompiledVectorExpression::getAllowedWidths()) {
CompiledVectorExpression vector2 = parsed.createCompiledVectorExpression(width);
vector2.setVariableLocations(vecVariablePointers);
for (int i = 0; i < width; i++) {
for (int j = 0; j < width; j++) {
xvec[j] = (i == j ? x : -100.0);
yvec[j] = (i == j ? y : -100.0);
}
const float* result = vector2.evaluate();
ASSERT_EQUAL_TOL(expectedValue, result[i], 1e-6);
}
}
// Make sure that variable renaming works.
variables.clear();
......@@ -143,12 +179,12 @@ void verifyInvalidExpression(const string& expression) {
* Verify that two numbers have the same value.
*/
void assertNumbersEqual(double val1, double val2) {
void assertNumbersEqual(double val1, double val2, double tol=1e-10) {
const double inf = numeric_limits<double>::infinity();
if (val1 == val1 || val2 == val2) // If both are NaN, that's fine.
if (val1 != inf || val2 != inf) // Both infinity is also fine.
if (val1 != -inf || val2 != -inf) // Same for -infinity.
ASSERT_EQUAL_TOL(val1, val2, 1e-10);
ASSERT_EQUAL_TOL(val1, val2, tol);
}
/**
......@@ -177,6 +213,31 @@ void verifySameValue(const ParsedExpression& exp1, const ParsedExpression& exp2,
compiled2.getVariableReference("y") = y;
assertNumbersEqual(val1, compiled1.evaluate());
assertNumbersEqual(val2, compiled2.evaluate());
// Now check CompiledVectorExpressions.
for (int width : CompiledVectorExpression::getAllowedWidths()) {
CompiledVectorExpression vector1 = exp1.createCompiledVectorExpression(width);
CompiledVectorExpression vector2 = exp2.createCompiledVectorExpression(width);
for (int i = 0; i < width; i++) {
if (vector1.getVariables().find("x") != vector1.getVariables().end())
for (int j = 0; j < width; j++)
vector1.getVariablePointer("x")[j] = (i == j ? x : -100.0);
if (vector1.getVariables().find("y") != vector1.getVariables().end())
for (int j = 0; j < width; j++)
vector1.getVariablePointer("y")[j] = (i == j ? y : -100.0);
if (vector2.getVariables().find("x") != vector2.getVariables().end())
for (int j = 0; j < width; j++)
vector2.getVariablePointer("x")[j] = (i == j ? x : -100.0);
if (vector2.getVariables().find("y") != vector2.getVariables().end())
for (int j = 0; j < width; j++)
vector2.getVariablePointer("y")[j] = (i == j ? y : -100.0);
const float* result1 = vector1.evaluate();
const float* result2 = vector2.evaluate();
assertNumbersEqual(val1, result1[i], 1e-6);
assertNumbersEqual(val2, result2[i], 1e-6);
}
}
}
/**
......@@ -235,6 +296,7 @@ int main() {
verifyEvaluation("2.1e-4*x*(y+1)", 3.0, 1.0, 1.26e-3);
verifyEvaluation("sin(2.5)", std::sin(2.5));
verifyEvaluation("cot(x)", 3.0, 1.0, 1.0/std::tan(3.0));
verifyEvaluation("log(x)", 3.0, 1.0, std::log(3.0));
verifyEvaluation("x^2+y^3+x^-1+y^(1/2)", 1.0, 1.0, 4.0);
verifyEvaluation("(2*x)*3", 4.0, 4.0, 24.0);
verifyEvaluation("(x*2)*3", 4.0, 4.0, 24.0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment