Commit 563914a7 authored by peastman's avatar peastman
Browse files

Optimizations to CPU version of GBSA

parent c6b335e7
......@@ -160,6 +160,21 @@ public:
ivec4 operator==(ivec4 other) const {
return _mm_cmpeq_epi32(val, other);
}
ivec4 operator!=(ivec4 other) const {
return _mm_xor_si128(val==other, _mm_set1_epi32(0xFFFFFFFF));
}
ivec4 operator>(ivec4 other) const {
return _mm_cmpgt_epi32(val, other);
}
ivec4 operator<(ivec4 other) const {
return _mm_cmplt_epi32(val, other);
}
ivec4 operator>=(ivec4 other) const {
return _mm_xor_si128(_mm_cmplt_epi32(val, other), _mm_set1_epi32(0xFFFFFFFF));
}
ivec4 operator<=(ivec4 other) const {
return _mm_xor_si128(_mm_cmpgt_epi32(val, other), _mm_set1_epi32(0xFFFFFFFF));
}
operator fvec4() const;
};
......@@ -230,6 +245,10 @@ static inline ivec4 abs(ivec4 v) {
return ivec4(_mm_abs_epi32(v.val));
}
static inline bool any(ivec4 v) {
return !_mm_test_all_zeros(v, _mm_set1_epi32(0xFFFFFFFF));
}
// Mathematical operators involving a scalar and a vector.
static inline fvec4 operator+(float v1, fvec4 v2) {
......
......@@ -107,10 +107,10 @@ private:
bool includeEnergy;
/**
* Compute the displacement and squared distance between two points, optionally using
* Compute the displacement and squared distance between a collection of points, optionally using
* periodic boundary conditions.
*/
void getDeltaR(const fvec4& posI, const fvec4& posJ, fvec4& deltaR, float& r2, bool periodic, const fvec4& boxSize, const fvec4& invBoxSize) const;
void getDeltaR(const fvec4& posI, const fvec4& x, const fvec4& y, const fvec4& z, fvec4& dx, fvec4& dy, fvec4& dz, fvec4& r2, bool periodic, const fvec4& boxSize, const fvec4& invBoxSize) const;
};
} // namespace OpenMM
......
......@@ -69,8 +69,8 @@ const std::vector<std::pair<float, float> >& CpuGBSAOBCForce::getParticleParamet
void CpuGBSAOBCForce::setParticleParameters(const std::vector<std::pair<float, float> >& params) {
particleParams = params;
bornRadii.resize(params.size());
obcChain.resize(params.size());
bornRadii.resize(params.size()+3);
obcChain.resize(params.size()+3);
}
void CpuGBSAOBCForce::computeForce(const std::vector<float>& posq, vector<vector<float> >& threadForce, double* totalEnergy, ThreadPool& threads) {
......@@ -83,7 +83,7 @@ void CpuGBSAOBCForce::computeForce(const std::vector<float>& posq, vector<vector
threadEnergy.resize(numThreads);
threadBornForces.resize(numThreads);
for (int i = 0; i < numThreads; i++)
threadBornForces[i].resize(particleParams.size());
threadBornForces[i].resize(particleParams.size()+3);
// Signal the threads to start running and wait for them to finish.
......@@ -121,45 +121,68 @@ void CpuGBSAOBCForce::threadComputeForce(ThreadPool& threads, int threadIndex) {
// Calculate Born radii
for (int atomI = start; atomI < end; atomI++) {
fvec4 posI(posq+4*atomI);
float offsetRadiusI = particleParams[atomI].first;
float radiusIInverse = 1.0f/offsetRadiusI;
float sum = 0.0f;
for (int blockStart = start; blockStart < end; blockStart += 4) {
int numInBlock = min(4, end-blockStart);
ivec4 blockAtomIndex(blockStart, blockStart+1, blockStart+2, blockStart+3);
float atomRadius[4], atomx[4], atomy[4], atomz[4];
int blockMask[4] = {0, 0, 0, 0};
for (int i = 0; i < numInBlock; i++) {
int atomIndex = blockStart+i;
atomRadius[i] = particleParams[atomIndex].first;
atomx[i] = posq[4*atomIndex];
atomy[i] = posq[4*atomIndex+1];
atomz[i] = posq[4*atomIndex+2];
blockMask[i] = 0xFFFFFFFF;
}
fvec4 offsetRadiusI(atomRadius);
fvec4 radiusIInverse = 1.0f/offsetRadiusI;
fvec4 x(atomx);
fvec4 y(atomy);
fvec4 z(atomz);
ivec4 mask(blockMask);
float sum[4] = {0.0f, 0.0f, 0.0f, 0.0f};
for (int atomJ = 0; atomJ < numParticles; atomJ++) {
if (atomJ != atomI) {
fvec4 posJ(posq+4*atomJ);
fvec4 deltaR;
float r2;
getDeltaR(posI, posJ, deltaR, r2, periodic, boxSize, invBoxSize);
if (cutoff && r2 >= cutoffDistance*cutoffDistance)
fvec4 dx, dy, dz, r2;
getDeltaR(posJ, x, y, z, dx, dy, dz, r2, periodic, boxSize, invBoxSize);
ivec4 include = mask & (blockAtomIndex != ivec4(atomJ));
if (cutoff)
include = include & (r2 < cutoffDistance*cutoffDistance);
if (!any(include))
continue;
float r = sqrtf(r2);
fvec4 r = sqrt(r2);
float scaledRadiusJ = particleParams[atomJ].second;
float rScaledRadiusJ = r + scaledRadiusJ;
if (offsetRadiusI < rScaledRadiusJ) {
float rInverse = 1.0f/r;
float l_ij = 1.0f/(offsetRadiusI > fabs(r - scaledRadiusJ) ? offsetRadiusI : fabs(r - scaledRadiusJ));
float u_ij = 1.0f/rScaledRadiusJ;
float l_ij2 = l_ij*l_ij;
float u_ij2 = u_ij*u_ij;
float ratio = log((u_ij/l_ij));
float term = l_ij - u_ij + 0.25f*r*(u_ij2 - l_ij2) + (0.5f*rInverse*ratio) + (0.25f*scaledRadiusJ*scaledRadiusJ*rInverse)*(l_ij2 - u_ij2);
if (offsetRadiusI < (scaledRadiusJ - r))
term += 2.0f*(radiusIInverse - l_ij);
sum += term;
float scaledRadiusJ2 = scaledRadiusJ*scaledRadiusJ;
fvec4 rScaledRadiusJ = r + scaledRadiusJ;
include = include & (offsetRadiusI < rScaledRadiusJ);
fvec4 l_ij = 1.0f/max(offsetRadiusI, abs(r-scaledRadiusJ));
fvec4 u_ij = 1.0f/rScaledRadiusJ;
fvec4 l_ij2 = l_ij*l_ij;
fvec4 u_ij2 = u_ij*u_ij;
fvec4 rInverse = 1.0f/r;
fvec4 r2Inverse = rInverse*rInverse;
fvec4 ratio = u_ij/l_ij;
fvec4 logRatio(logf(ratio[0]), logf(ratio[1]), logf(ratio[2]), logf(ratio[3]));
fvec4 term = l_ij - u_ij + 0.25f*r*(u_ij2 - l_ij2) + (0.5f*rInverse*logRatio) + (0.25f*scaledRadiusJ*scaledRadiusJ*rInverse)*(l_ij2 - u_ij2);
for (int j = 0; j < 4; j++) {
if (include[j]) {
sum[j] += term[j];
if (offsetRadiusI[j] < scaledRadiusJ-r[j])
sum[j] += 2.0f*(radiusIInverse[j]-l_ij[j]);
}
}
}
sum *= 0.5f*offsetRadiusI;
float sum2 = sum*sum;
float sum3 = sum*sum2;
float tanhSum = tanh(alphaObc*sum - betaObc*sum2 + gammaObc*sum3);
float radiusI = offsetRadiusI + dielectricOffset;
bornRadii[atomI] = 1.0f/(1.0f/offsetRadiusI - tanhSum/radiusI);
obcChain[atomI] = offsetRadiusI*(alphaObc - 2.0f*betaObc*sum + 3.0f*gammaObc*sum2);
obcChain[atomI] = (1.0f - tanhSum*tanhSum)*obcChain[atomI]/radiusI;
for (int i = 0; i < numInBlock; i++) {
int atomIndex = blockStart+i;
sum[i] *= 0.5f*offsetRadiusI[i];
float sum2 = sum[i]*sum[i];
float sum3 = sum[i]*sum2;
float tanhSum = tanh(alphaObc*sum[i] - betaObc*sum2 + gammaObc*sum3);
float radiusI = offsetRadiusI[i] + dielectricOffset;
bornRadii[atomIndex] = 1.0f/(1.0f/offsetRadiusI[i] - tanhSum/radiusI);
obcChain[atomIndex] = offsetRadiusI[i]*(alphaObc - 2.0f*betaObc*sum[i] + 3.0f*gammaObc*sum2);
obcChain[atomIndex] = (1.0f - tanhSum*tanhSum)*obcChain[atomIndex]/radiusI;
}
}
threads.syncThreads();
......@@ -193,93 +216,150 @@ void CpuGBSAOBCForce::threadComputeForce(ThreadPool& threads, int threadIndex) {
preFactor = ONE_4PI_EPS0*((1.0f/solventDielectric) - (1.0f/soluteDielectric));
else
preFactor = 0.0f;
for (int atomI = start; atomI < end; atomI++) {
fvec4 posI(posq+4*atomI);
fvec4 forceI(0.0f);
float partialChargeI = preFactor*posI[3];
for (int atomJ = atomI; atomJ < numParticles; atomJ++) {
for (int blockStart = start; blockStart < end; blockStart += 4) {
int numInBlock = min(4, end-blockStart);
ivec4 blockAtomIndex(blockStart, blockStart+1, blockStart+2, blockStart+3);
float atomCharge[4], atomx[4], atomy[4], atomz[4];
int blockMask[4] = {0, 0, 0, 0};
fvec4 blockAtomForce[4];
for (int i = 0; i < numInBlock; i++) {
int atomIndex = blockStart+i;
atomx[i] = posq[4*atomIndex];
atomy[i] = posq[4*atomIndex+1];
atomz[i] = posq[4*atomIndex+2];
atomCharge[i] = preFactor*posq[4*atomIndex+3];
blockMask[i] = 0xFFFFFFFF;
blockAtomForce[i] = 0.0f;
}
fvec4 radii(&bornRadii[blockStart]);
fvec4 x(atomx);
fvec4 y(atomy);
fvec4 z(atomz);
fvec4 partialChargeI(atomCharge);
ivec4 mask(blockMask);
for (int atomJ = blockStart; atomJ < numParticles; atomJ++) {
fvec4 posJ(posq+4*atomJ);
fvec4 deltaR;
float r2;
getDeltaR(posI, posJ, deltaR, r2, periodic, boxSize, invBoxSize);
if (cutoff && r2 >= cutoffDistance*cutoffDistance)
fvec4 dx, dy, dz, r2;
getDeltaR(posJ, x, y, z, dx, dy, dz, r2, periodic, boxSize, invBoxSize);
ivec4 include = mask & (blockAtomIndex <= ivec4(atomJ));
if (cutoff)
include = include & (r2 < cutoffDistance*cutoffDistance);
if (!any(include))
continue;
float r = sqrtf(r2);
float alpha2_ij = bornRadii[atomI]*bornRadii[atomJ];
float D_ij = r2/(4.0f*alpha2_ij);
float expTerm = exp(-D_ij);
float denominator2 = r2 + alpha2_ij*expTerm;
float denominator = sqrt(denominator2);
float Gpol = (partialChargeI*posJ[3])/denominator;
float dGpol_dr = -Gpol*(1.0f - 0.25f*expTerm)/denominator2;
float dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f + D_ij)/denominator2;
float termEnergy = Gpol;
if (atomI != atomJ) {
fvec4 r = sqrt(r2);
fvec4 alpha2_ij = radii*bornRadii[atomJ];
fvec4 D_ij = r2/(4.0f*alpha2_ij);
fvec4 expTerm(exp(-D_ij[0]), exp(-D_ij[1]), exp(-D_ij[2]), exp(-D_ij[3]));
fvec4 denominator2 = r2 + alpha2_ij*expTerm;
fvec4 denominator = sqrt(denominator2);
fvec4 Gpol = (partialChargeI*posJ[3])/denominator;
fvec4 dGpol_dr = -Gpol*(1.0f - 0.25f*expTerm)/denominator2;
fvec4 dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f + D_ij)/denominator2;
fvec4 result[4] = {dx*dGpol_dr, dy*dGpol_dr, dz*dGpol_dr, 0.0f};
transpose(result[0], result[1], result[2], result[3]);
fvec4 atomForce(forces+4*atomJ);
for (int j = 0; j < 4; j++) {
if (include[j]) {
float termEnergy = Gpol[j];
if (blockStart+j != atomJ) {
if (cutoff)
termEnergy -= partialChargeI*posJ[3]/cutoffDistance;
bornForces[atomJ] += dGpol_dalpha2_ij*bornRadii[atomI];
fvec4 result = deltaR*dGpol_dr;
forceI += result;
(fvec4(forces+4*atomJ)-result).store(forces+4*atomJ);
termEnergy -= partialChargeI[j]*posJ[3]/cutoffDistance;
bornForces[atomJ] += dGpol_dalpha2_ij[j]*radii[j];
blockAtomForce[j] -= result[j];
(fvec4(forces+4*atomJ)+result[j]).store(forces+4*atomJ);
}
else
termEnergy *= 0.5f;
energy += termEnergy;
bornForces[atomI] += dGpol_dalpha2_ij*bornRadii[atomJ];
bornForces[blockStart+j] += dGpol_dalpha2_ij[j]*bornRadii[atomJ];
}
}
}
for (int i = 0; i < numInBlock; i++) {
int atomIndex = blockStart+i;
(fvec4(forces+4*atomIndex)+blockAtomForce[i]).store(forces+4*atomIndex);
}
(fvec4(forces+4*atomI)+forceI).store(forces+4*atomI);
}
threads.syncThreads();
// Second loop of Born energy computation.
for (int atomI = start; atomI < end; atomI++) {
float bornForce = 0;
for (int blockStart = start; blockStart < end; blockStart += 4) {
fvec4 bornForce(0.0f);
for (int i = 0; i < numThreads; i++)
bornForce += threadBornForces[i][atomI];
bornForce *= bornRadii[atomI]*bornRadii[atomI]*obcChain[atomI];
float offsetRadiusI = particleParams[atomI].first;
fvec4 posI(posq+4*atomI);
fvec4 forceI(0.0f);
bornForce += fvec4(&threadBornForces[i][blockStart]);
fvec4 radii(&bornRadii[blockStart]);
bornForce *= radii*radii*fvec4(&obcChain[blockStart]);
int numInBlock = min(4, end-blockStart);
ivec4 blockAtomIndex(blockStart, blockStart+1, blockStart+2, blockStart+3);
float atomRadius[4], atomx[4], atomy[4], atomz[4];
int blockMask[4] = {0, 0, 0, 0};
fvec4 blockAtomForce[4];
for (int i = 0; i < numInBlock; i++) {
int atomIndex = blockStart+i;
atomRadius[i] = particleParams[atomIndex].first;
atomx[i] = posq[4*atomIndex];
atomy[i] = posq[4*atomIndex+1];
atomz[i] = posq[4*atomIndex+2];
blockMask[i] = 0xFFFFFFFF;
blockAtomForce[i] = 0.0f;
}
fvec4 offsetRadiusI(atomRadius);
fvec4 x(atomx);
fvec4 y(atomy);
fvec4 z(atomz);
ivec4 mask(blockMask);
for (int atomJ = 0; atomJ < numParticles; atomJ++) {
if (atomJ != atomI) {
fvec4 posJ(posq+4*atomJ);
fvec4 deltaR;
float r2;
getDeltaR(posI, posJ, deltaR, r2, periodic, boxSize, invBoxSize);
if (cutoff && r2 >= cutoffDistance*cutoffDistance)
fvec4 dx, dy, dz, r2;
getDeltaR(posJ, x, y, z, dx, dy, dz, r2, periodic, boxSize, invBoxSize);
ivec4 include = mask & (blockAtomIndex != ivec4(atomJ));
if (cutoff)
include = include & (r2 < cutoffDistance*cutoffDistance);
if (!any(include))
continue;
float r = sqrtf(r2);
fvec4 r = sqrt(r2);
float scaledRadiusJ = particleParams[atomJ].second;
float scaledRadiusJ2 = scaledRadiusJ*scaledRadiusJ;
float rScaledRadiusJ = r + scaledRadiusJ;
if (offsetRadiusI < rScaledRadiusJ) {
float l_ij = 1.0f/(offsetRadiusI > fabs(r - scaledRadiusJ) ? offsetRadiusI : fabs(r - scaledRadiusJ));
float u_ij = 1.0f/rScaledRadiusJ;
float l_ij2 = l_ij*l_ij;
float u_ij2 = u_ij*u_ij;
float rInverse = 1.0f/r;
float r2Inverse = rInverse*rInverse;
float t3 = 0.125f*(1.0f + scaledRadiusJ2*r2Inverse)*(l_ij2 - u_ij2) + 0.25f*log(u_ij/l_ij)*r2Inverse;
float de = bornForce*t3*rInverse;
fvec4 result = deltaR*de;
forceI -= result;
(fvec4(forces+4*atomJ)+result).store(forces+4*atomJ);
fvec4 rScaledRadiusJ = r + scaledRadiusJ;
include = include & (offsetRadiusI < rScaledRadiusJ);
fvec4 l_ij = 1.0f/max(offsetRadiusI, abs(r-scaledRadiusJ));
fvec4 u_ij = 1.0f/rScaledRadiusJ;
fvec4 l_ij2 = l_ij*l_ij;
fvec4 u_ij2 = u_ij*u_ij;
fvec4 rInverse = 1.0f/r;
fvec4 r2Inverse = rInverse*rInverse;
fvec4 ratio = u_ij/l_ij;
fvec4 logRatio(logf(ratio[0]), logf(ratio[1]), logf(ratio[2]), logf(ratio[3]));
fvec4 t3 = 0.125f*(1.0f + scaledRadiusJ2*r2Inverse)*(l_ij2 - u_ij2) + 0.25f*logRatio*r2Inverse;
fvec4 de = bornForce*t3*rInverse;
fvec4 result[4] = {dx*de, dy*de, dz*de, 0.0f};
transpose(result[0], result[1], result[2], result[3]);
fvec4 atomForce(forces+4*atomJ);
for (int j = 0; j < 4; j++) {
if (include[j]) {
blockAtomForce[j] += result[j];
atomForce -= result[j];
}
}
atomForce.store(forces+4*atomJ);
}
for (int i = 0; i < numInBlock; i++) {
int atomIndex = blockStart+i;
(fvec4(forces+4*atomIndex)+blockAtomForce[i]).store(forces+4*atomIndex);
}
(fvec4(forces+4*atomI)+forceI).store(forces+4*atomI);
}
threadEnergy[threadIndex] = energy;
}
void CpuGBSAOBCForce::getDeltaR(const fvec4& posI, const fvec4& posJ, fvec4& deltaR, float& r2, bool periodic, const fvec4& boxSize, const fvec4& invBoxSize) const {
deltaR = posJ-posI;
void CpuGBSAOBCForce::getDeltaR(const fvec4& posI, const fvec4& x, const fvec4& y, const fvec4& z, fvec4& dx, fvec4& dy, fvec4& dz, fvec4& r2, bool periodic, const fvec4& boxSize, const fvec4& invBoxSize) const {
dx = x-posI[0];
dy = y-posI[1];
dz = z-posI[2];
if (periodic) {
fvec4 base = round(deltaR*invBoxSize)*boxSize;
deltaR = deltaR-base;
dx -= round(dx*invBoxSize[0])*boxSize[0];
dy -= round(dy*invBoxSize[1])*boxSize[1];
dz -= round(dz*invBoxSize[2])*boxSize[2];
}
r2 = dot3(deltaR, deltaR);
r2 = dx*dx + dy*dy + dz*dz;
}
......@@ -223,7 +223,6 @@ void testForce(int numParticles, NonbondedForce::NonbondedMethod method, GBSAOBC
int main() {
try {
testForce(729, NonbondedForce::CutoffNonPeriodic, GBSAOBCForce::CutoffNonPeriodic);
testSingleParticle();
testCutoffAndPeriodic();
for (int i = 5; i < 11; i++) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment