Commit 7ac8f663 authored by peastman's avatar peastman
Browse files

Merge pull request #207 from peastman/master

Optimizations to CPU version of GBSA
parents 412d42c9 563914a7
...@@ -160,6 +160,21 @@ public: ...@@ -160,6 +160,21 @@ public:
ivec4 operator==(ivec4 other) const { ivec4 operator==(ivec4 other) const {
return _mm_cmpeq_epi32(val, other); return _mm_cmpeq_epi32(val, other);
} }
ivec4 operator!=(ivec4 other) const {
return _mm_xor_si128(val==other, _mm_set1_epi32(0xFFFFFFFF));
}
ivec4 operator>(ivec4 other) const {
return _mm_cmpgt_epi32(val, other);
}
ivec4 operator<(ivec4 other) const {
return _mm_cmplt_epi32(val, other);
}
ivec4 operator>=(ivec4 other) const {
return _mm_xor_si128(_mm_cmplt_epi32(val, other), _mm_set1_epi32(0xFFFFFFFF));
}
ivec4 operator<=(ivec4 other) const {
return _mm_xor_si128(_mm_cmpgt_epi32(val, other), _mm_set1_epi32(0xFFFFFFFF));
}
operator fvec4() const; operator fvec4() const;
}; };
...@@ -230,6 +245,10 @@ static inline ivec4 abs(ivec4 v) { ...@@ -230,6 +245,10 @@ static inline ivec4 abs(ivec4 v) {
return ivec4(_mm_abs_epi32(v.val)); return ivec4(_mm_abs_epi32(v.val));
} }
static inline bool any(ivec4 v) {
return !_mm_test_all_zeros(v, _mm_set1_epi32(0xFFFFFFFF));
}
// Mathematical operators involving a scalar and a vector. // Mathematical operators involving a scalar and a vector.
static inline fvec4 operator+(float v1, fvec4 v2) { static inline fvec4 operator+(float v1, fvec4 v2) {
......
...@@ -106,11 +106,11 @@ private: ...@@ -106,11 +106,11 @@ private:
std::vector<std::vector<float> >* threadForce; std::vector<std::vector<float> >* threadForce;
bool includeEnergy; bool includeEnergy;
/** /**
* Compute the displacement and squared distance between two points, optionally using * Compute the displacement and squared distance between a collection of points, optionally using
* periodic boundary conditions. * periodic boundary conditions.
*/ */
void getDeltaR(const fvec4& posI, const fvec4& posJ, fvec4& deltaR, float& r2, bool periodic, const fvec4& boxSize, const fvec4& invBoxSize) const; void getDeltaR(const fvec4& posI, const fvec4& x, const fvec4& y, const fvec4& z, fvec4& dx, fvec4& dy, fvec4& dz, fvec4& r2, bool periodic, const fvec4& boxSize, const fvec4& invBoxSize) const;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -69,8 +69,8 @@ const std::vector<std::pair<float, float> >& CpuGBSAOBCForce::getParticleParamet ...@@ -69,8 +69,8 @@ const std::vector<std::pair<float, float> >& CpuGBSAOBCForce::getParticleParamet
void CpuGBSAOBCForce::setParticleParameters(const std::vector<std::pair<float, float> >& params) { void CpuGBSAOBCForce::setParticleParameters(const std::vector<std::pair<float, float> >& params) {
particleParams = params; particleParams = params;
bornRadii.resize(params.size()); bornRadii.resize(params.size()+3);
obcChain.resize(params.size()); obcChain.resize(params.size()+3);
} }
void CpuGBSAOBCForce::computeForce(const std::vector<float>& posq, vector<vector<float> >& threadForce, double* totalEnergy, ThreadPool& threads) { void CpuGBSAOBCForce::computeForce(const std::vector<float>& posq, vector<vector<float> >& threadForce, double* totalEnergy, ThreadPool& threads) {
...@@ -83,7 +83,7 @@ void CpuGBSAOBCForce::computeForce(const std::vector<float>& posq, vector<vector ...@@ -83,7 +83,7 @@ void CpuGBSAOBCForce::computeForce(const std::vector<float>& posq, vector<vector
threadEnergy.resize(numThreads); threadEnergy.resize(numThreads);
threadBornForces.resize(numThreads); threadBornForces.resize(numThreads);
for (int i = 0; i < numThreads; i++) for (int i = 0; i < numThreads; i++)
threadBornForces[i].resize(particleParams.size()); threadBornForces[i].resize(particleParams.size()+3);
// Signal the threads to start running and wait for them to finish. // Signal the threads to start running and wait for them to finish.
...@@ -121,45 +121,68 @@ void CpuGBSAOBCForce::threadComputeForce(ThreadPool& threads, int threadIndex) { ...@@ -121,45 +121,68 @@ void CpuGBSAOBCForce::threadComputeForce(ThreadPool& threads, int threadIndex) {
// Calculate Born radii // Calculate Born radii
for (int atomI = start; atomI < end; atomI++) { for (int blockStart = start; blockStart < end; blockStart += 4) {
fvec4 posI(posq+4*atomI); int numInBlock = min(4, end-blockStart);
float offsetRadiusI = particleParams[atomI].first; ivec4 blockAtomIndex(blockStart, blockStart+1, blockStart+2, blockStart+3);
float radiusIInverse = 1.0f/offsetRadiusI; float atomRadius[4], atomx[4], atomy[4], atomz[4];
float sum = 0.0f; int blockMask[4] = {0, 0, 0, 0};
for (int i = 0; i < numInBlock; i++) {
int atomIndex = blockStart+i;
atomRadius[i] = particleParams[atomIndex].first;
atomx[i] = posq[4*atomIndex];
atomy[i] = posq[4*atomIndex+1];
atomz[i] = posq[4*atomIndex+2];
blockMask[i] = 0xFFFFFFFF;
}
fvec4 offsetRadiusI(atomRadius);
fvec4 radiusIInverse = 1.0f/offsetRadiusI;
fvec4 x(atomx);
fvec4 y(atomy);
fvec4 z(atomz);
ivec4 mask(blockMask);
float sum[4] = {0.0f, 0.0f, 0.0f, 0.0f};
for (int atomJ = 0; atomJ < numParticles; atomJ++) { for (int atomJ = 0; atomJ < numParticles; atomJ++) {
if (atomJ != atomI) { fvec4 posJ(posq+4*atomJ);
fvec4 posJ(posq+4*atomJ); fvec4 dx, dy, dz, r2;
fvec4 deltaR; getDeltaR(posJ, x, y, z, dx, dy, dz, r2, periodic, boxSize, invBoxSize);
float r2; ivec4 include = mask & (blockAtomIndex != ivec4(atomJ));
getDeltaR(posI, posJ, deltaR, r2, periodic, boxSize, invBoxSize); if (cutoff)
if (cutoff && r2 >= cutoffDistance*cutoffDistance) include = include & (r2 < cutoffDistance*cutoffDistance);
continue; if (!any(include))
float r = sqrtf(r2); continue;
float scaledRadiusJ = particleParams[atomJ].second; fvec4 r = sqrt(r2);
float rScaledRadiusJ = r + scaledRadiusJ; float scaledRadiusJ = particleParams[atomJ].second;
if (offsetRadiusI < rScaledRadiusJ) { float scaledRadiusJ2 = scaledRadiusJ*scaledRadiusJ;
float rInverse = 1.0f/r; fvec4 rScaledRadiusJ = r + scaledRadiusJ;
float l_ij = 1.0f/(offsetRadiusI > fabs(r - scaledRadiusJ) ? offsetRadiusI : fabs(r - scaledRadiusJ)); include = include & (offsetRadiusI < rScaledRadiusJ);
float u_ij = 1.0f/rScaledRadiusJ; fvec4 l_ij = 1.0f/max(offsetRadiusI, abs(r-scaledRadiusJ));
float l_ij2 = l_ij*l_ij; fvec4 u_ij = 1.0f/rScaledRadiusJ;
float u_ij2 = u_ij*u_ij; fvec4 l_ij2 = l_ij*l_ij;
float ratio = log((u_ij/l_ij)); fvec4 u_ij2 = u_ij*u_ij;
float term = l_ij - u_ij + 0.25f*r*(u_ij2 - l_ij2) + (0.5f*rInverse*ratio) + (0.25f*scaledRadiusJ*scaledRadiusJ*rInverse)*(l_ij2 - u_ij2); fvec4 rInverse = 1.0f/r;
if (offsetRadiusI < (scaledRadiusJ - r)) fvec4 r2Inverse = rInverse*rInverse;
term += 2.0f*(radiusIInverse - l_ij); fvec4 ratio = u_ij/l_ij;
sum += term; fvec4 logRatio(logf(ratio[0]), logf(ratio[1]), logf(ratio[2]), logf(ratio[3]));
fvec4 term = l_ij - u_ij + 0.25f*r*(u_ij2 - l_ij2) + (0.5f*rInverse*logRatio) + (0.25f*scaledRadiusJ*scaledRadiusJ*rInverse)*(l_ij2 - u_ij2);
for (int j = 0; j < 4; j++) {
if (include[j]) {
sum[j] += term[j];
if (offsetRadiusI[j] < scaledRadiusJ-r[j])
sum[j] += 2.0f*(radiusIInverse[j]-l_ij[j]);
} }
} }
} }
sum *= 0.5f*offsetRadiusI; for (int i = 0; i < numInBlock; i++) {
float sum2 = sum*sum; int atomIndex = blockStart+i;
float sum3 = sum*sum2; sum[i] *= 0.5f*offsetRadiusI[i];
float tanhSum = tanh(alphaObc*sum - betaObc*sum2 + gammaObc*sum3); float sum2 = sum[i]*sum[i];
float radiusI = offsetRadiusI + dielectricOffset; float sum3 = sum[i]*sum2;
bornRadii[atomI] = 1.0f/(1.0f/offsetRadiusI - tanhSum/radiusI); float tanhSum = tanh(alphaObc*sum[i] - betaObc*sum2 + gammaObc*sum3);
obcChain[atomI] = offsetRadiusI*(alphaObc - 2.0f*betaObc*sum + 3.0f*gammaObc*sum2); float radiusI = offsetRadiusI[i] + dielectricOffset;
obcChain[atomI] = (1.0f - tanhSum*tanhSum)*obcChain[atomI]/radiusI; bornRadii[atomIndex] = 1.0f/(1.0f/offsetRadiusI[i] - tanhSum/radiusI);
obcChain[atomIndex] = offsetRadiusI[i]*(alphaObc - 2.0f*betaObc*sum[i] + 3.0f*gammaObc*sum2);
obcChain[atomIndex] = (1.0f - tanhSum*tanhSum)*obcChain[atomIndex]/radiusI;
}
} }
threads.syncThreads(); threads.syncThreads();
...@@ -193,93 +216,150 @@ void CpuGBSAOBCForce::threadComputeForce(ThreadPool& threads, int threadIndex) { ...@@ -193,93 +216,150 @@ void CpuGBSAOBCForce::threadComputeForce(ThreadPool& threads, int threadIndex) {
preFactor = ONE_4PI_EPS0*((1.0f/solventDielectric) - (1.0f/soluteDielectric)); preFactor = ONE_4PI_EPS0*((1.0f/solventDielectric) - (1.0f/soluteDielectric));
else else
preFactor = 0.0f; preFactor = 0.0f;
for (int atomI = start; atomI < end; atomI++) { for (int blockStart = start; blockStart < end; blockStart += 4) {
fvec4 posI(posq+4*atomI); int numInBlock = min(4, end-blockStart);
fvec4 forceI(0.0f); ivec4 blockAtomIndex(blockStart, blockStart+1, blockStart+2, blockStart+3);
float partialChargeI = preFactor*posI[3]; float atomCharge[4], atomx[4], atomy[4], atomz[4];
for (int atomJ = atomI; atomJ < numParticles; atomJ++) { int blockMask[4] = {0, 0, 0, 0};
fvec4 blockAtomForce[4];
for (int i = 0; i < numInBlock; i++) {
int atomIndex = blockStart+i;
atomx[i] = posq[4*atomIndex];
atomy[i] = posq[4*atomIndex+1];
atomz[i] = posq[4*atomIndex+2];
atomCharge[i] = preFactor*posq[4*atomIndex+3];
blockMask[i] = 0xFFFFFFFF;
blockAtomForce[i] = 0.0f;
}
fvec4 radii(&bornRadii[blockStart]);
fvec4 x(atomx);
fvec4 y(atomy);
fvec4 z(atomz);
fvec4 partialChargeI(atomCharge);
ivec4 mask(blockMask);
for (int atomJ = blockStart; atomJ < numParticles; atomJ++) {
fvec4 posJ(posq+4*atomJ); fvec4 posJ(posq+4*atomJ);
fvec4 deltaR; fvec4 dx, dy, dz, r2;
float r2; getDeltaR(posJ, x, y, z, dx, dy, dz, r2, periodic, boxSize, invBoxSize);
getDeltaR(posI, posJ, deltaR, r2, periodic, boxSize, invBoxSize); ivec4 include = mask & (blockAtomIndex <= ivec4(atomJ));
if (cutoff && r2 >= cutoffDistance*cutoffDistance) if (cutoff)
include = include & (r2 < cutoffDistance*cutoffDistance);
if (!any(include))
continue; continue;
float r = sqrtf(r2); fvec4 r = sqrt(r2);
float alpha2_ij = bornRadii[atomI]*bornRadii[atomJ]; fvec4 alpha2_ij = radii*bornRadii[atomJ];
float D_ij = r2/(4.0f*alpha2_ij); fvec4 D_ij = r2/(4.0f*alpha2_ij);
float expTerm = exp(-D_ij); fvec4 expTerm(exp(-D_ij[0]), exp(-D_ij[1]), exp(-D_ij[2]), exp(-D_ij[3]));
float denominator2 = r2 + alpha2_ij*expTerm; fvec4 denominator2 = r2 + alpha2_ij*expTerm;
float denominator = sqrt(denominator2); fvec4 denominator = sqrt(denominator2);
float Gpol = (partialChargeI*posJ[3])/denominator; fvec4 Gpol = (partialChargeI*posJ[3])/denominator;
float dGpol_dr = -Gpol*(1.0f - 0.25f*expTerm)/denominator2; fvec4 dGpol_dr = -Gpol*(1.0f - 0.25f*expTerm)/denominator2;
float dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f + D_ij)/denominator2; fvec4 dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f + D_ij)/denominator2;
float termEnergy = Gpol; fvec4 result[4] = {dx*dGpol_dr, dy*dGpol_dr, dz*dGpol_dr, 0.0f};
if (atomI != atomJ) { transpose(result[0], result[1], result[2], result[3]);
if (cutoff) fvec4 atomForce(forces+4*atomJ);
termEnergy -= partialChargeI*posJ[3]/cutoffDistance; for (int j = 0; j < 4; j++) {
bornForces[atomJ] += dGpol_dalpha2_ij*bornRadii[atomI]; if (include[j]) {
fvec4 result = deltaR*dGpol_dr; float termEnergy = Gpol[j];
forceI += result; if (blockStart+j != atomJ) {
(fvec4(forces+4*atomJ)-result).store(forces+4*atomJ); if (cutoff)
termEnergy -= partialChargeI[j]*posJ[3]/cutoffDistance;
bornForces[atomJ] += dGpol_dalpha2_ij[j]*radii[j];
blockAtomForce[j] -= result[j];
(fvec4(forces+4*atomJ)+result[j]).store(forces+4*atomJ);
}
else
termEnergy *= 0.5f;
energy += termEnergy;
bornForces[blockStart+j] += dGpol_dalpha2_ij[j]*bornRadii[atomJ];
}
} }
else
termEnergy *= 0.5f;
energy += termEnergy;
bornForces[atomI] += dGpol_dalpha2_ij*bornRadii[atomJ];
} }
(fvec4(forces+4*atomI)+forceI).store(forces+4*atomI); for (int i = 0; i < numInBlock; i++) {
int atomIndex = blockStart+i;
(fvec4(forces+4*atomIndex)+blockAtomForce[i]).store(forces+4*atomIndex);
}
} }
threads.syncThreads(); threads.syncThreads();
// Second loop of Born energy computation. // Second loop of Born energy computation.
for (int atomI = start; atomI < end; atomI++) { for (int blockStart = start; blockStart < end; blockStart += 4) {
float bornForce = 0; fvec4 bornForce(0.0f);
for (int i = 0; i < numThreads; i++) for (int i = 0; i < numThreads; i++)
bornForce += threadBornForces[i][atomI]; bornForce += fvec4(&threadBornForces[i][blockStart]);
bornForce *= bornRadii[atomI]*bornRadii[atomI]*obcChain[atomI]; fvec4 radii(&bornRadii[blockStart]);
float offsetRadiusI = particleParams[atomI].first; bornForce *= radii*radii*fvec4(&obcChain[blockStart]);
fvec4 posI(posq+4*atomI); int numInBlock = min(4, end-blockStart);
fvec4 forceI(0.0f); ivec4 blockAtomIndex(blockStart, blockStart+1, blockStart+2, blockStart+3);
float atomRadius[4], atomx[4], atomy[4], atomz[4];
int blockMask[4] = {0, 0, 0, 0};
fvec4 blockAtomForce[4];
for (int i = 0; i < numInBlock; i++) {
int atomIndex = blockStart+i;
atomRadius[i] = particleParams[atomIndex].first;
atomx[i] = posq[4*atomIndex];
atomy[i] = posq[4*atomIndex+1];
atomz[i] = posq[4*atomIndex+2];
blockMask[i] = 0xFFFFFFFF;
blockAtomForce[i] = 0.0f;
}
fvec4 offsetRadiusI(atomRadius);
fvec4 x(atomx);
fvec4 y(atomy);
fvec4 z(atomz);
ivec4 mask(blockMask);
for (int atomJ = 0; atomJ < numParticles; atomJ++) { for (int atomJ = 0; atomJ < numParticles; atomJ++) {
if (atomJ != atomI) { fvec4 posJ(posq+4*atomJ);
fvec4 posJ(posq+4*atomJ); fvec4 dx, dy, dz, r2;
fvec4 deltaR; getDeltaR(posJ, x, y, z, dx, dy, dz, r2, periodic, boxSize, invBoxSize);
float r2; ivec4 include = mask & (blockAtomIndex != ivec4(atomJ));
getDeltaR(posI, posJ, deltaR, r2, periodic, boxSize, invBoxSize); if (cutoff)
if (cutoff && r2 >= cutoffDistance*cutoffDistance) include = include & (r2 < cutoffDistance*cutoffDistance);
continue; if (!any(include))
float r = sqrtf(r2); continue;
float scaledRadiusJ = particleParams[atomJ].second; fvec4 r = sqrt(r2);
float scaledRadiusJ2 = scaledRadiusJ*scaledRadiusJ; float scaledRadiusJ = particleParams[atomJ].second;
float rScaledRadiusJ = r + scaledRadiusJ; float scaledRadiusJ2 = scaledRadiusJ*scaledRadiusJ;
if (offsetRadiusI < rScaledRadiusJ) { fvec4 rScaledRadiusJ = r + scaledRadiusJ;
float l_ij = 1.0f/(offsetRadiusI > fabs(r - scaledRadiusJ) ? offsetRadiusI : fabs(r - scaledRadiusJ)); include = include & (offsetRadiusI < rScaledRadiusJ);
float u_ij = 1.0f/rScaledRadiusJ; fvec4 l_ij = 1.0f/max(offsetRadiusI, abs(r-scaledRadiusJ));
float l_ij2 = l_ij*l_ij; fvec4 u_ij = 1.0f/rScaledRadiusJ;
float u_ij2 = u_ij*u_ij; fvec4 l_ij2 = l_ij*l_ij;
float rInverse = 1.0f/r; fvec4 u_ij2 = u_ij*u_ij;
float r2Inverse = rInverse*rInverse; fvec4 rInverse = 1.0f/r;
float t3 = 0.125f*(1.0f + scaledRadiusJ2*r2Inverse)*(l_ij2 - u_ij2) + 0.25f*log(u_ij/l_ij)*r2Inverse; fvec4 r2Inverse = rInverse*rInverse;
float de = bornForce*t3*rInverse; fvec4 ratio = u_ij/l_ij;
fvec4 result = deltaR*de; fvec4 logRatio(logf(ratio[0]), logf(ratio[1]), logf(ratio[2]), logf(ratio[3]));
forceI -= result; fvec4 t3 = 0.125f*(1.0f + scaledRadiusJ2*r2Inverse)*(l_ij2 - u_ij2) + 0.25f*logRatio*r2Inverse;
(fvec4(forces+4*atomJ)+result).store(forces+4*atomJ); fvec4 de = bornForce*t3*rInverse;
fvec4 result[4] = {dx*de, dy*de, dz*de, 0.0f};
transpose(result[0], result[1], result[2], result[3]);
fvec4 atomForce(forces+4*atomJ);
for (int j = 0; j < 4; j++) {
if (include[j]) {
blockAtomForce[j] += result[j];
atomForce -= result[j];
} }
} }
atomForce.store(forces+4*atomJ);
}
for (int i = 0; i < numInBlock; i++) {
int atomIndex = blockStart+i;
(fvec4(forces+4*atomIndex)+blockAtomForce[i]).store(forces+4*atomIndex);
} }
(fvec4(forces+4*atomI)+forceI).store(forces+4*atomI);
} }
threadEnergy[threadIndex] = energy; threadEnergy[threadIndex] = energy;
} }
void CpuGBSAOBCForce::getDeltaR(const fvec4& posI, const fvec4& posJ, fvec4& deltaR, float& r2, bool periodic, const fvec4& boxSize, const fvec4& invBoxSize) const { void CpuGBSAOBCForce::getDeltaR(const fvec4& posI, const fvec4& x, const fvec4& y, const fvec4& z, fvec4& dx, fvec4& dy, fvec4& dz, fvec4& r2, bool periodic, const fvec4& boxSize, const fvec4& invBoxSize) const {
deltaR = posJ-posI; dx = x-posI[0];
dy = y-posI[1];
dz = z-posI[2];
if (periodic) { if (periodic) {
fvec4 base = round(deltaR*invBoxSize)*boxSize; dx -= round(dx*invBoxSize[0])*boxSize[0];
deltaR = deltaR-base; dy -= round(dy*invBoxSize[1])*boxSize[1];
dz -= round(dz*invBoxSize[2])*boxSize[2];
} }
r2 = dot3(deltaR, deltaR); r2 = dx*dx + dy*dy + dz*dz;
} }
...@@ -223,7 +223,6 @@ void testForce(int numParticles, NonbondedForce::NonbondedMethod method, GBSAOBC ...@@ -223,7 +223,6 @@ void testForce(int numParticles, NonbondedForce::NonbondedMethod method, GBSAOBC
int main() { int main() {
try { try {
testForce(729, NonbondedForce::CutoffNonPeriodic, GBSAOBCForce::CutoffNonPeriodic);
testSingleParticle(); testSingleParticle();
testCutoffAndPeriodic(); testCutoffAndPeriodic();
for (int i = 5; i < 11; i++) { for (int i = 5; i < 11; i++) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment