Commit 1c416f54 authored by peastman's avatar peastman
Browse files

Optimization to CustomGBForce on CPU

parent a468fa3a
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013 Stanford University and the Authors. *
* Portions copyright (c) 2013-2016 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -37,6 +37,7 @@
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
#ifdef LEPTON_USE_JIT
#include "asmjit.h"
......@@ -73,6 +74,12 @@ public:
* to set the value of the variable before calling evaluate().
*/
double& getVariableReference(const std::string& name);
/**
* You can optionally specify the memory locations from which the values of variables should be read.
* This is useful, for example, when several expressions all use the same variable. You can then set
* the value of that variable in one place, and it will be seen by all of them.
*/
void setVariableLocations(std::map<std::string, double*>& variableLocations);
/**
* Evaluate the expression. The values of all variables should have been set before calling this.
*/
......@@ -82,6 +89,8 @@ private:
CompiledExpression(const ParsedExpression& expression);
void compileExpression(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps);
int findTempIndex(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps);
std::map<std::string, double*> variablePointers;
std::vector<std::pair<double*, double*> > variablesToCopy;
std::vector<std::vector<int> > arguments;
std::vector<int> target;
std::vector<Operation*> operation;
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013 Stanford University and the Authors. *
* Portions copyright (c) 2013-2016 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -77,10 +77,7 @@ CompiledExpression& CompiledExpression::operator=(const CompiledExpression& expr
operation.resize(expression.operation.size());
for (int i = 0; i < (int) operation.size(); i++)
operation[i] = expression.operation[i]->clone();
#ifdef LEPTON_USE_JIT
if (workspace.size() > 0)
generateJitCode();
#endif
setVariableLocations(variablePointers);
return *this;
}
......@@ -138,16 +135,41 @@ const set<string>& CompiledExpression::getVariables() const {
}
double& CompiledExpression::getVariableReference(const string& name) {
map<string, double*>::iterator pointer = variablePointers.find(name);
if (pointer != variablePointers.end())
return *pointer->second;
map<string, int>::iterator index = variableIndices.find(name);
if (index == variableIndices.end())
throw Exception("getVariableReference: Unknown variable '"+name+"'");
return workspace[index->second];
}
void CompiledExpression::setVariableLocations(map<string, double*>& variableLocations) {
variablePointers = variableLocations;
#ifdef LEPTON_USE_JIT
// Rebuild the JIT code.
if (workspace.size() > 0)
generateJitCode();
#else
// Make a list of all variables we will need to copy before evaluating the expression.
variablesToCopy.clear();
for (map<string, int>::const_iterator iter = variableIndices.begin(); iter != variableIndices.end(); ++iter) {
map<string, double*>::iterator pointer = variablePointers.find(iter->first);
if (pointer != variablePointers.end())
variablesToCopy.push_back(make_pair(&workspace[iter->second], pointer->second));
}
#endif
}
double CompiledExpression::evaluate() const {
#ifdef LEPTON_USE_JIT
return ((double (*)()) jitCode)();
#else
for (int i = 0; i < variablesToCopy.size(); i++)
*variablesToCopy[i].first = *variablesToCopy[i].second;
// Loop over the operations and evaluate each one.
for (int step = 0; step < operation.size(); step++) {
......@@ -176,16 +198,16 @@ void CompiledExpression::generateJitCode() {
vector<X86XmmVar> workspaceVar(workspace.size());
for (int i = 0; i < (int) workspaceVar.size(); i++)
workspaceVar[i] = c.newXmmVar(kX86VarTypeXmmSd);
X86GpVar workspacePointer(c);
X86GpVar argsPointer(c);
c.mov(workspacePointer, imm_ptr(&workspace[0]));
c.mov(argsPointer, imm_ptr(&argValues[0]));
// Load the arguments into variables.
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
map<string, int>::iterator index = variableIndices.find(*iter);
c.movsd(workspaceVar[index->second], x86::ptr(workspacePointer, 8*index->second, 0));
X86GpVar variablePointer(c);
c.mov(variablePointer, imm_ptr(&getVariableReference(index->first)));
c.movsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0));
}
// Make a list of all constants that will be needed for evaluation.
......
/* Portions copyright (c) 2009-2014 Stanford University and Simbios.
/* Portions copyright (c) 2009-2016 Stanford University and Simbios.
* Contributors: Peter Eastman
*
* Permission is hereby granted, free of charge, to any person obtaining
......@@ -254,15 +254,15 @@ public:
std::vector<std::vector<Lepton::CompiledExpression> > valueDerivExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > valueGradientExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > valueParamDerivExpressions;
std::vector<int> valueIndex;
std::vector<double> value;
std::vector<Lepton::CompiledExpression> energyExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > energyDerivExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > energyGradientExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > energyParamDerivExpressions;
std::vector<int> paramIndex;
std::vector<int> particleParamIndex;
std::vector<int> particleValueIndex;
int xindex, yindex, zindex, rindex;
std::vector<double> param;
std::vector<double> particleParam;
std::vector<double> particleValue;
double x, y, z, r;
int firstAtom, lastAtom;
// Workspace vectors
std::vector<float> value0, dVdR1, dVdR2, dVdX, dVdY, dVdZ;
......
......@@ -59,47 +59,68 @@ CpuCustomGBForce::ThreadData::ThreadData(int numAtoms, int numThreads, int threa
energyGradientExpressions(energyGradientExpressions), energyParamDerivExpressions(energyParamDerivExpressions) {
firstAtom = (threadIndex*(long long) numAtoms)/numThreads;
lastAtom = ((threadIndex+1)*(long long) numAtoms)/numThreads;
for (int i = 0; i < (int) valueExpressions.size(); i++)
map<string, double*> variableLocations;
variableLocations["x"] = &x;
variableLocations["y"] = &y;
variableLocations["z"] = &z;
variableLocations["r"] = &r;
param.resize(parameterNames.size());
particleParam.resize(parameterNames.size()*2);
for (int i = 0; i < (int) parameterNames.size(); i++) {
variableLocations[parameterNames[i]] = &param[i];
for (int j = 0; j < 2; j++) {
stringstream name;
name << parameterNames[i] << (j+1);
variableLocations[name.str()] = &particleParam[2*i+j];
}
}
value.resize(valueNames.size());
particleValue.resize(valueNames.size()*2);
for (int i = 0; i < (int) valueNames.size(); i++) {
variableLocations[valueNames[i]] = &value[i];
for (int j = 0; j < 2; j++) {
stringstream name;
name << valueNames[i] << (j+1);
variableLocations[name.str()] = &particleValue[2*i+j];
}
}
for (int i = 0; i < (int) valueExpressions.size(); i++) {
this->valueExpressions[i].setVariableLocations(variableLocations);
expressionSet.registerExpression(this->valueExpressions[i]);
}
for (int i = 0; i < (int) valueDerivExpressions.size(); i++)
for (int j = 0; j < (int) valueDerivExpressions[i].size(); j++)
for (int j = 0; j < (int) valueDerivExpressions[i].size(); j++) {
this->valueDerivExpressions[i][j].setVariableLocations(variableLocations);
expressionSet.registerExpression(this->valueDerivExpressions[i][j]);
}
for (int i = 0; i < (int) valueGradientExpressions.size(); i++)
for (int j = 0; j < (int) valueGradientExpressions[i].size(); j++)
for (int j = 0; j < (int) valueGradientExpressions[i].size(); j++) {
this->valueGradientExpressions[i][j].setVariableLocations(variableLocations);
expressionSet.registerExpression(this->valueGradientExpressions[i][j]);
}
for (int i = 0; i < (int) valueParamDerivExpressions.size(); i++)
for (int j = 0; j < (int) valueParamDerivExpressions[i].size(); j++)
for (int j = 0; j < (int) valueParamDerivExpressions[i].size(); j++) {
this->valueParamDerivExpressions[i][j].setVariableLocations(variableLocations);
expressionSet.registerExpression(this->valueParamDerivExpressions[i][j]);
for (int i = 0; i < (int) energyExpressions.size(); i++)
}
for (int i = 0; i < (int) energyExpressions.size(); i++) {
this->energyExpressions[i].setVariableLocations(variableLocations);
expressionSet.registerExpression(this->energyExpressions[i]);
}
for (int i = 0; i < (int) energyDerivExpressions.size(); i++)
for (int j = 0; j < (int) energyDerivExpressions[i].size(); j++)
for (int j = 0; j < (int) energyDerivExpressions[i].size(); j++) {
this->energyDerivExpressions[i][j].setVariableLocations(variableLocations);
expressionSet.registerExpression(this->energyDerivExpressions[i][j]);
}
for (int i = 0; i < (int) energyGradientExpressions.size(); i++)
for (int j = 0; j < (int) energyGradientExpressions[i].size(); j++)
for (int j = 0; j < (int) energyGradientExpressions[i].size(); j++) {
this->energyGradientExpressions[i][j].setVariableLocations(variableLocations);
expressionSet.registerExpression(this->energyGradientExpressions[i][j]);
}
for (int i = 0; i < (int) energyParamDerivExpressions.size(); i++)
for (int j = 0; j < (int) energyParamDerivExpressions[i].size(); j++)
for (int j = 0; j < (int) energyParamDerivExpressions[i].size(); j++) {
this->energyParamDerivExpressions[i][j].setVariableLocations(variableLocations);
expressionSet.registerExpression(this->energyParamDerivExpressions[i][j]);
xindex = expressionSet.getVariableIndex("x");
yindex = expressionSet.getVariableIndex("y");
zindex = expressionSet.getVariableIndex("z");
rindex = expressionSet.getVariableIndex("r");
for (int i = 0; i < (int) parameterNames.size(); i++) {
paramIndex.push_back(expressionSet.getVariableIndex(parameterNames[i]));
for (int j = 1; j < 3; j++) {
stringstream name;
name << parameterNames[i] << j;
particleParamIndex.push_back(expressionSet.getVariableIndex(name.str()));
}
}
for (int i = 0; i < (int) valueNames.size(); i++) {
valueIndex.push_back(expressionSet.getVariableIndex(valueNames[i]));
for (int j = 1; j < 3; j++) {
stringstream name;
name << valueNames[i] << j;
particleValueIndex.push_back(expressionSet.getVariableIndex(name.str()));
}
}
value0.resize(numAtoms);
dEdV.resize(valueNames.size());
......@@ -283,13 +304,13 @@ void CpuCustomGBForce::threadComputeForce(ThreadPool& threads, int threadIndex)
for (int j = 0; j < (int) threadData.size(); j++)
sum += threadData[j]->value0[atom];
values[0][atom] = sum;
data.expressionSet.setVariable(data.xindex, posq[4*atom]);
data.expressionSet.setVariable(data.yindex, posq[4*atom+1]);
data.expressionSet.setVariable(data.zindex, posq[4*atom+2]);
data.x = posq[4*atom];
data.y = posq[4*atom+1];
data.z = posq[4*atom+2];
for (int j = 0; j < numParams; j++)
data.expressionSet.setVariable(data.paramIndex[j], atomParameters[atom][j]);
data.param[j] = atomParameters[atom][j];
for (int i = 1; i < numValues; i++) {
data.expressionSet.setVariable(data.valueIndex[i-1], values[i-1][atom]);
data.value[i-1] = values[i-1][atom];
values[i][atom] = (float) data.valueExpressions[i].evaluate();
// Calculate derivatives with respect to parameters.
......@@ -397,15 +418,14 @@ void CpuCustomGBForce::calculateOnePairValue(int index, int atom1, int atom2, Th
getDeltaR(pos2, pos1, deltaR, r2, periodic, boxSize, invBoxSize);
if (cutoff && r2 >= cutoffDistance2)
return;
float r = sqrtf(r2);
data.r = sqrtf(r2);
for (int i = 0; i < numParams; i++) {
data.expressionSet.setVariable(data.particleParamIndex[i*2], atomParameters[atom1][i]);
data.expressionSet.setVariable(data.particleParamIndex[i*2+1], atomParameters[atom2][i]);
data.particleParam[i*2] = atomParameters[atom1][i];
data.particleParam[i*2+1] = atomParameters[atom2][i];
}
data.expressionSet.setVariable(data.rindex, r);
for (int i = 0; i < index; i++) {
data.expressionSet.setVariable(data.particleValueIndex[i*2], values[i][atom1]);
data.expressionSet.setVariable(data.particleValueIndex[i*2+1], values[i][atom2]);
data.particleValue[i*2] = values[i][atom1];
data.particleValue[i*2+1] = values[i][atom2];
}
valueArray[atom1] += (float) data.valueExpressions[index].evaluate();
......@@ -418,13 +438,13 @@ void CpuCustomGBForce::calculateOnePairValue(int index, int atom1, int atom2, Th
void CpuCustomGBForce::calculateSingleParticleEnergyTerm(int index, ThreadData& data, int numAtoms, float* posq,
RealOpenMM** atomParameters, float* forces, double& totalEnergy) {
for (int i = data.firstAtom; i < data.lastAtom; i++) {
data.expressionSet.setVariable(data.xindex, posq[4*i]);
data.expressionSet.setVariable(data.yindex, posq[4*i+1]);
data.expressionSet.setVariable(data.zindex, posq[4*i+2]);
data.x = posq[4*i];
data.y = posq[4*i+1];
data.z = posq[4*i+2];
for (int j = 0; j < numParams; j++)
data.expressionSet.setVariable(data.paramIndex[j], atomParameters[i][j]);
data.param[j] = atomParameters[i][j];
for (int j = 0; j < (int) values.size(); j++)
data.expressionSet.setVariable(data.valueIndex[j], values[j][i]);
data.value[j] = values[j][i];
if (includeEnergy)
totalEnergy += (float) data.energyExpressions[index].evaluate();
for (int j = 0; j < (int) values.size(); j++)
......@@ -493,18 +513,17 @@ void CpuCustomGBForce::calculateOnePairEnergyTerm(int index, int atom1, int atom
getDeltaR(pos2, pos1, deltaR, r2, periodic, boxSize, invBoxSize);
if (cutoff && r2 >= cutoffDistance2)
return;
float r = sqrtf(r2);
data.r = sqrtf(r2);
// Record variables for evaluating expressions.
for (int i = 0; i < numParams; i++) {
data.expressionSet.setVariable(data.particleParamIndex[i*2], atomParameters[atom1][i]);
data.expressionSet.setVariable(data.particleParamIndex[i*2+1], atomParameters[atom2][i]);
data.particleParam[i*2] = atomParameters[atom1][i];
data.particleParam[i*2+1] = atomParameters[atom2][i];
}
data.expressionSet.setVariable(data.rindex, r);
for (int i = 0; i < (int) values.size(); i++) {
data.expressionSet.setVariable(data.particleValueIndex[i*2], values[i][atom1]);
data.expressionSet.setVariable(data.particleValueIndex[i*2+1], values[i][atom2]);
data.particleValue[i*2] = values[i][atom1];
data.particleValue[i*2+1] = values[i][atom2];
}
// Evaluate the energy and its derivatives.
......@@ -512,7 +531,7 @@ void CpuCustomGBForce::calculateOnePairEnergyTerm(int index, int atom1, int atom
if (includeEnergy)
totalEnergy += (float) data.energyExpressions[index].evaluate();
float dEdR = (float) data.energyDerivExpressions[index][0].evaluate();
dEdR *= 1/r;
dEdR *= 1/data.r;
fvec4 result = deltaR*dEdR;
(fvec4(forces+4*atom1)-result).store(forces+4*atom1);
(fvec4(forces+4*atom2)+result).store(forces+4*atom2);
......@@ -571,13 +590,13 @@ void CpuCustomGBForce::calculateChainRuleForces(ThreadData& data, int numAtoms,
// Compute chain rule terms for computed values that depend explicitly on particle coordinates.
for (int i = data.firstAtom; i < data.lastAtom; i++) {
data.expressionSet.setVariable(data.xindex, posq[4*i]);
data.expressionSet.setVariable(data.yindex, posq[4*i+1]);
data.expressionSet.setVariable(data.zindex, posq[4*i+2]);
data.x = posq[4*i];
data.y = posq[4*i+1];
data.z = posq[4*i+2];
for (int j = 0; j < numParams; j++)
data.expressionSet.setVariable(data.paramIndex[j], atomParameters[i][j]);
data.param[j] = atomParameters[i][j];
for (int j = 1; j < (int) values.size(); j++) {
data.expressionSet.setVariable(data.valueIndex[j-1], values[j-1][i]);
data.value[j-1] = values[j-1][i];
data.dVdX[j] = 0.0;
data.dVdY[j] = 0.0;
data.dVdZ[j] = 0.0;
......@@ -599,7 +618,7 @@ void CpuCustomGBForce::calculateChainRuleForces(ThreadData& data, int numAtoms,
// Compute chain rule terms for derivatives with respect to parameters.
for (int i = data.firstAtom; i < data.lastAtom; i++)
for (int j = 0; j < data.valueIndex.size(); j++)
for (int j = 0; j < data.value.size(); j++)
for (int k = 0; k < dValuedParam[j].size(); k++)
data.energyParamDerivs[k] += dEdV[j][i]*dValuedParam[j][k][i];
}
......@@ -615,26 +634,25 @@ void CpuCustomGBForce::calculateOnePairChainRule(int atom1, int atom2, ThreadDat
getDeltaR(pos2, pos1, deltaR, r2, periodic, boxSize, invBoxSize);
if (cutoff && r2 >= cutoffDistance2)
return;
float r = sqrtf(r2);
data.r = sqrtf(r2);
// Record variables for evaluating expressions.
for (int i = 0; i < numParams; i++) {
data.expressionSet.setVariable(data.particleParamIndex[i*2], atomParameters[atom1][i]);
data.expressionSet.setVariable(data.particleParamIndex[i*2+1], atomParameters[atom2][i]);
data.expressionSet.setVariable(data.paramIndex[i], atomParameters[atom1][i]);
}
data.expressionSet.setVariable(data.valueIndex[0], values[0][atom1]);
data.expressionSet.setVariable(data.xindex, posq[4*atom1]);
data.expressionSet.setVariable(data.yindex, posq[4*atom1+1]);
data.expressionSet.setVariable(data.zindex, posq[4*atom1+2]);
data.expressionSet.setVariable(data.rindex, r);
data.expressionSet.setVariable(data.particleValueIndex[0], values[0][atom1]);
data.expressionSet.setVariable(data.particleValueIndex[1], values[0][atom2]);
data.particleParam[i*2] = atomParameters[atom1][i];
data.particleParam[i*2+1] = atomParameters[atom2][i];
data.param[i] = atomParameters[atom1][i];
}
data.value[0] = values[0][atom1];
data.x = posq[4*atom1];
data.y = posq[4*atom1+1];
data.z = posq[4*atom1+2];
data.particleValue[0] = values[0][atom1];
data.particleValue[1] = values[0][atom2];
// Evaluate the derivative of each parameter with respect to position and apply forces.
float rinv = 1/r;
float rinv = 1/data.r;
deltaR *= rinv;
fvec4 f1(0.0f), f2(0.0f);
if (!isExcluded || valueTypes[0] != CustomGBForce::ParticlePair) {
......@@ -644,7 +662,7 @@ void CpuCustomGBForce::calculateOnePairChainRule(int atom1, int atom2, ThreadDat
f2 -= deltaR*(dEdV[0][atom1]*data.dVdR2[0]);
}
for (int i = 1; i < (int) values.size(); i++) {
data.expressionSet.setVariable(data.valueIndex[i], values[i][atom1]);
data.value[i] = values[i][atom1];
data.dVdR1[i] = 0.0;
data.dVdR2[i] = 0.0;
for (int j = 0; j < i; j++) {
......
#include "../libraries/lepton/include/Lepton.h"
#include "openmm/internal/AssertionUtilities.h"
#include <iostream>
#include <limits>
#include <map>
using namespace Lepton;
using namespace OpenMM;
using namespace std;
#define ASSERT_EQUAL_TOL(expected, found, tol) {double _scale_ = std::fabs(expected) > 1.0 ? std::fabs(expected) : 1.0; if (!(std::fabs((expected)-(found))/_scale_ <= (tol))) throw exception();};
/**
* This is a custom function equal to f(x,y) = 2*x*y.
*/
......@@ -102,6 +102,18 @@ void verifyEvaluation(const string& expression, double x, double y, double expec
value = compiled.evaluate();
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
// Try specifying memory locations for the compiled expression.
map<string, double*> variablePointers;
variablePointers["x"] = &x;
variablePointers["y"] = &y;
CompiledExpression compiled2 = parsed.createCompiledExpression();
compiled2.setVariableLocations(variablePointers);
value = compiled2.evaluate();
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
ASSERT_EQUAL(&x, &compiled2.getVariableReference("x"));
ASSERT_EQUAL(&y, &compiled2.getVariableReference("y"));
// Make sure that variable renaming works.
variables.clear();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment