"vscode:/vscode.git/clone" did not exist on "143fe36df644a3fd605491da44bbc4875dbe0d31"
Unverified commit e1a926f4, authored by Peter Eastman and committed by GitHub
Browse files

Optimized computing powers in CompiledExpression (#3520)

* Optimized computing powers in CompiledExpression

* Fixed compilation error

* Attempt at fixing compilation error
parent cc7018ea
...@@ -7,6 +7,7 @@ dependencies: ...@@ -7,6 +7,7 @@ dependencies:
- cmake - cmake
- make - make
- ccache - ccache
- sysroot_linux-64 2.17
# host # host
- python - python
- cython - cython
......
...@@ -105,6 +105,7 @@ private: ...@@ -105,6 +105,7 @@ private:
std::map<std::string, double> dummyVariables; std::map<std::string, double> dummyVariables;
double (*jitCode)(); double (*jitCode)();
#ifdef LEPTON_USE_JIT #ifdef LEPTON_USE_JIT
void findPowerGroups(std::vector<std::vector<int> >& groups, std::vector<std::vector<int> >& groupPowers, std::vector<int>& stepGroup);
void generateJitCode(); void generateJitCode();
#if defined(__ARM__) || defined(__ARM64__) #if defined(__ARM__) || defined(__ARM64__)
void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, double (*function)(double)); void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, double (*function)(double));
......
...@@ -192,6 +192,48 @@ static double evaluateOperation(Operation* op, double* args) { ...@@ -192,6 +192,48 @@ static double evaluateOperation(Operation* op, double* args) {
return op->evaluate(args, dummyVariables); return op->evaluate(args, dummyVariables);
} }
/**
 * Partition the compiled steps into "power groups": sets of steps that raise the same
 * argument (identified by its workspace index, arguments[step][0]) to nonzero integer
 * powers of the same sign.  The JIT code generator uses these groups to compute all
 * the powers of one argument together.
 *
 * All three output vectors are reset on entry, so any previous contents are discarded.
 *
 * @param groups       on exit, groups[g] lists the indices of the steps in group g, in increasing order
 * @param groupPowers  on exit, groupPowers[g][k] is the integer exponent applied by step groups[g][k]
 * @param stepGroup    on exit, stepGroup[step] is the index of the group containing that step,
 *                     or -1 if the step does not raise its argument to a nonzero integer power
 */
void CompiledExpression::findPowerGroups(vector<vector<int> >& groups, vector<vector<int> >& groupPowers, vector<int>& stepGroup) {
    const int numSteps = (int) operation.size();

    // Reset the outputs.  assign() (unlike resize()) also overwrites entries that already
    // exist, so stale group indices left in a reused vector cannot cause steps to be skipped.
    groups.clear();
    groupPowers.clear();
    stepGroup.assign(numSteps, -1);

    // Identify every step that raises an argument to an integer power.
    vector<int> stepPower(numSteps, 0);
    vector<int> stepArg(numSteps, -1);
    for (int step = 0; step < numSteps; step++) {
        Operation& op = *operation[step];
        int power = 0;
        if (op.getId() == Operation::SQUARE)
            power = 2;
        else if (op.getId() == Operation::CUBE)
            power = 3;
        else if (op.getId() == Operation::POWER_CONSTANT) {
            // Use the reference form of dynamic_cast so a mismatched type throws
            // std::bad_cast instead of dereferencing a null pointer.
            const double realPower = dynamic_cast<const Operation::PowerConstant&>(op).getValue();
            // Converting a double outside the range of int (or a NaN/infinity) to int is
            // undefined behavior, so check the range first.  The bounds assume int is at
            // least 32 bits.  The lower bound is exclusive because the code generator
            // negates negative powers, and the most negative int cannot be negated.
            const bool inIntRange = (realPower > -2147483648.0 && realPower < 2147483648.0);
            if (inIntRange && realPower == (int) realPower)
                power = (int) realPower;
        }
        if (power != 0) {
            stepPower[step] = power;
            stepArg[step] = arguments[step][0];
        }
    }

    // Find groups that operate on the same argument and whose powers have the same sign.
    for (int first = 0; first < numSteps; first++) {
        // Steps that are not integer powers never belong to a group, and steps that were
        // already placed in an earlier group must not start a new one.
        if (stepPower[first] == 0 || stepGroup[first] != -1)
            continue;
        const int groupIndex = (int) groups.size();
        const bool positive = (stepPower[first] > 0);
        vector<int> group, power;
        for (int other = first; other < numSteps; other++) {
            // Compare signs directly rather than testing the sign of the product, which
            // could overflow for large exponents.
            if (stepPower[other] != 0 && stepArg[other] == stepArg[first] && (stepPower[other] > 0) == positive) {
                stepGroup[other] = groupIndex;
                group.push_back(other);
                power.push_back(stepPower[other]);
            }
        }
        groups.push_back(group);
        groupPowers.push_back(power);
    }
}
#if defined(__ARM__) || defined(__ARM64__) #if defined(__ARM__) || defined(__ARM64__)
void CompiledExpression::generateJitCode() { void CompiledExpression::generateJitCode() {
CodeHolder code; CodeHolder code;
...@@ -203,6 +245,9 @@ void CompiledExpression::generateJitCode() { ...@@ -203,6 +245,9 @@ void CompiledExpression::generateJitCode() {
workspaceVar[i] = c.newVecD(); workspaceVar[i] = c.newVecD();
arm::Gp argsPointer = c.newIntPtr(); arm::Gp argsPointer = c.newIntPtr();
c.mov(argsPointer, imm(&argValues[0])); c.mov(argsPointer, imm(&argValues[0]));
vector<vector<int> > groups, groupPowers;
vector<int> stepGroup;
findPowerGroups(groups, groupPowers, stepGroup);
// Load the arguments into variables. // Load the arguments into variables.
...@@ -233,6 +278,12 @@ void CompiledExpression::generateJitCode() { ...@@ -233,6 +278,12 @@ void CompiledExpression::generateJitCode() {
value = 1.0; value = 1.0;
else if (op.getId() == Operation::DELTA) else if (op.getId() == Operation::DELTA)
value = 1.0; value = 1.0;
else if (op.getId() == Operation::POWER_CONSTANT) {
if (stepGroup[step] == -1)
value = dynamic_cast<Operation::PowerConstant&>(op).getValue();
else
value = 1.0;
}
else else
continue; continue;
...@@ -260,10 +311,54 @@ void CompiledExpression::generateJitCode() { ...@@ -260,10 +311,54 @@ void CompiledExpression::generateJitCode() {
c.ldr(constantVar[i], arm::ptr(constantsPointer, 8*i)); c.ldr(constantVar[i], arm::ptr(constantsPointer, 8*i));
} }
} }
// Evaluate the operations. // Evaluate the operations.
vector<bool> hasComputedPower(operation.size(), false);
for (int step = 0; step < (int) operation.size(); step++) { for (int step = 0; step < (int) operation.size(); step++) {
if (hasComputedPower[step])
continue;
// When one or more steps involve raising the same argument to multiple integer
// powers, we can compute them all together for efficiency.
if (stepGroup[step] != -1) {
vector<int>& group = groups[stepGroup[step]];
vector<int>& powers = groupPowers[stepGroup[step]];
arm::Vec multiplier = c.newVecD();
if (powers[0] > 0)
c.fmov(multiplier, workspaceVar[arguments[step][0]]);
else {
c.fdiv(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]);
for (int i = 0; i < powers.size(); i++)
powers[i] = -powers[i];
}
vector<bool> hasAssigned(group.size(), false);
bool done = false;
while (!done) {
done = true;
for (int i = 0; i < group.size(); i++) {
if (powers[i]%2 == 1) {
if (!hasAssigned[i])
c.fmov(workspaceVar[target[group[i]]], multiplier);
else
c.fmul(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier);
hasAssigned[i] = true;
}
powers[i] >>= 1;
if (powers[i] != 0)
done = false;
}
if (!done)
c.fmul(multiplier, multiplier, multiplier);
}
for (int step : group)
hasComputedPower[step] = true;
continue;
}
// Evaluate the step.
Operation& op = *operation[step]; Operation& op = *operation[step];
vector<int> args = arguments[step]; vector<int> args = arguments[step];
if (args.size() == 1) { if (args.size() == 1) {
...@@ -360,6 +455,9 @@ void CompiledExpression::generateJitCode() { ...@@ -360,6 +455,9 @@ void CompiledExpression::generateJitCode() {
case Operation::MULTIPLY_CONSTANT: case Operation::MULTIPLY_CONSTANT:
c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break; break;
case Operation::POWER_CONSTANT:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow);
break;
case Operation::ABS: case Operation::ABS:
c.fabs(workspaceVar[target[step]], workspaceVar[args[0]]); c.fabs(workspaceVar[target[step]], workspaceVar[args[0]]);
break; break;
...@@ -418,7 +516,10 @@ void CompiledExpression::generateJitCode() { ...@@ -418,7 +516,10 @@ void CompiledExpression::generateJitCode() {
workspaceVar[i] = c.newXmmSd(); workspaceVar[i] = c.newXmmSd();
x86::Gp argsPointer = c.newIntPtr(); x86::Gp argsPointer = c.newIntPtr();
c.mov(argsPointer, imm(&argValues[0])); c.mov(argsPointer, imm(&argValues[0]));
vector<vector<int> > groups, groupPowers;
vector<int> stepGroup;
findPowerGroups(groups, groupPowers, stepGroup);
// Load the arguments into variables. // Load the arguments into variables.
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
...@@ -448,6 +549,12 @@ void CompiledExpression::generateJitCode() { ...@@ -448,6 +549,12 @@ void CompiledExpression::generateJitCode() {
value = 1.0; value = 1.0;
else if (op.getId() == Operation::DELTA) else if (op.getId() == Operation::DELTA)
value = 1.0; value = 1.0;
else if (op.getId() == Operation::POWER_CONSTANT) {
if (stepGroup[step] == -1)
value = dynamic_cast<Operation::PowerConstant&>(op).getValue();
else
value = 1.0;
}
else else
continue; continue;
...@@ -478,7 +585,52 @@ void CompiledExpression::generateJitCode() { ...@@ -478,7 +585,52 @@ void CompiledExpression::generateJitCode() {
// Evaluate the operations. // Evaluate the operations.
vector<bool> hasComputedPower(operation.size(), false);
for (int step = 0; step < (int) operation.size(); step++) { for (int step = 0; step < (int) operation.size(); step++) {
if (hasComputedPower[step])
continue;
// When one or more steps involve raising the same argument to multiple integer
// powers, we can compute them all together for efficiency.
if (stepGroup[step] != -1) {
vector<int>& group = groups[stepGroup[step]];
vector<int>& powers = groupPowers[stepGroup[step]];
x86::Xmm multiplier = c.newXmmSd();
if (powers[0] > 0)
c.movsd(multiplier, workspaceVar[arguments[step][0]]);
else {
c.movsd(multiplier, constantVar[operationConstantIndex[step]]);
c.divsd(multiplier, workspaceVar[arguments[step][0]]);
for (int i = 0; i < powers.size(); i++)
powers[i] = -powers[i];
}
vector<bool> hasAssigned(group.size(), false);
bool done = false;
while (!done) {
done = true;
for (int i = 0; i < group.size(); i++) {
if (powers[i]%2 == 1) {
if (!hasAssigned[i])
c.movsd(workspaceVar[target[group[i]]], multiplier);
else
c.mulsd(workspaceVar[target[group[i]]], multiplier);
hasAssigned[i] = true;
}
powers[i] >>= 1;
if (powers[i] != 0)
done = false;
}
if (!done)
c.mulsd(multiplier, multiplier);
}
for (int step : group)
hasComputedPower[step] = true;
continue;
}
// Evaluate the step.
Operation& op = *operation[step]; Operation& op = *operation[step];
vector<int> args = arguments[step]; vector<int> args = arguments[step];
if (args.size() == 1) { if (args.size() == 1) {
...@@ -587,6 +739,9 @@ void CompiledExpression::generateJitCode() { ...@@ -587,6 +739,9 @@ void CompiledExpression::generateJitCode() {
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break; break;
case Operation::POWER_CONSTANT:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow);
break;
case Operation::ABS: case Operation::ABS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs); generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs);
break; break;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment