Commit 98b09e6f authored by Peter Eastman
Browse files

Optimization for ATI: replaced if's with predication

parent 2c6cff12
...@@ -37,10 +37,11 @@ if (!isExcluded || needCorrection) { ...@@ -37,10 +37,11 @@ if (!isExcluded || needCorrection) {
dEdR += tempForce*invR*invR; dEdR += tempForce*invR*invR;
} }
#else #else
{
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
if (!isExcluded && r2 < CUTOFF_SQUARED) { unsigned int includeInteraction = (!isExcluded && r2 < CUTOFF_SQUARED);
#else #else
if (!isExcluded) { unsigned int includeInteraction = (!isExcluded);
#endif #endif
float tempForce = 0.0f; float tempForce = 0.0f;
#if HAS_LENNARD_JONES #if HAS_LENNARD_JONES
...@@ -50,19 +51,19 @@ if (!isExcluded) { ...@@ -50,19 +51,19 @@ if (!isExcluded) {
float sig6 = sig2*sig2*sig2; float sig6 = sig2*sig2*sig2;
float eps = sigmaEpsilon1.y*sigmaEpsilon2.y; float eps = sigmaEpsilon1.y*sigmaEpsilon2.y;
tempForce = eps*(12.0f*sig6 - 6.0f)*sig6; tempForce = eps*(12.0f*sig6 - 6.0f)*sig6;
tempEnergy += eps*(sig6 - 1.0f)*sig6; tempEnergy += select(0.0f, eps*(sig6 - 1.0f)*sig6, includeInteraction);
#endif #endif
#if HAS_COULOMB #if HAS_COULOMB
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
const float prefactor = 138.935456f*posq1.w*posq2.w; const float prefactor = 138.935456f*posq1.w*posq2.w;
tempForce += prefactor*(invR - 2.0f*REACTION_FIELD_K*r2); tempForce += prefactor*(invR - 2.0f*REACTION_FIELD_K*r2);
tempEnergy += prefactor*(invR + REACTION_FIELD_K*r2 - REACTION_FIELD_C); tempEnergy += select(0.0f, prefactor*(invR + REACTION_FIELD_K*r2 - REACTION_FIELD_C), includeInteraction);
#else #else
const float prefactor = 138.935456f*posq1.w*posq2.w*invR; const float prefactor = 138.935456f*posq1.w*posq2.w*invR;
tempForce += prefactor; tempForce += prefactor;
tempEnergy += prefactor; tempEnergy += select(0.0f, prefactor, includeInteraction);
#endif #endif
#endif #endif
dEdR += tempForce*invR*invR; dEdR += select(0.0f, tempForce*invR*invR, includeInteraction);
} }
#endif #endif
\ No newline at end of file
#ifdef USE_CUTOFF {
if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS && atom1 != atom2 && r2 < CUTOFF_SQUARED) {
#else
if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS && atom1 != atom2) {
#endif
float invRSquared = RECIP(r2); float invRSquared = RECIP(r2);
float rScaledRadiusJ = r+obcParams2.y; float rScaledRadiusJ = r+obcParams2.y;
float rScaledRadiusI = r+obcParams1.y; float rScaledRadiusI = r+obcParams1.y;
...@@ -24,6 +20,12 @@ if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS && atom1 != atom2) { ...@@ -24,6 +20,12 @@ if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS && atom1 != atom2) {
t1I *= invR; t1I *= invR;
float term1 = 0.125f*(1.0f+obcParams2.y*obcParams2.y*invRSquared)*t3J + 0.25f*t1J*invRSquared; float term1 = 0.125f*(1.0f+obcParams2.y*obcParams2.y*invRSquared)*t3J + 0.25f*t1J*invRSquared;
float term2 = 0.125f*(1.0f+obcParams1.y*obcParams1.y*invRSquared)*t3I + 0.25f*t1I*invRSquared; float term2 = 0.125f*(1.0f+obcParams1.y*obcParams1.y*invRSquared)*t3I + 0.25f*t1I*invRSquared;
dEdR += (obcParams1.x < rScaledRadiusJ ? bornForce1*term1 : 0.0f); float tempdEdR = select(0.0f, bornForce1*term1, obcParams1.x < rScaledRadiusJ);
dEdR += (obcParams2.x < rScaledRadiusJ ? bornForce2*term2 : 0.0f); tempdEdR += select(0.0f, bornForce2*term2, obcParams2.x < rScaledRadiusJ);
#ifdef USE_CUTOFF
unsigned int includeInteraction = (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS && atom1 != atom2 && r2 < CUTOFF_SQUARED);
#else
unsigned int includeInteraction = (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS && atom1 != atom2);
#endif
dEdR += select(0.0f, tempdEdR, includeInteraction);
} }
...@@ -61,27 +61,23 @@ void computeBornSum(__global float* global_bornSum, __global float4* posq, __glo ...@@ -61,27 +61,23 @@ void computeBornSum(__global float* global_bornSum, __global float4* posq, __glo
delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z; delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif #endif
float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z; float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
float invR = RSQRT(r2);
float r = RECIP(invR);
float2 params2 = (float2) (localData[baseLocalAtom+j].radius, localData[baseLocalAtom+j].scaledRadius);
float rScaledRadiusJ = r+params2.y;
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
if (atom1 < NUM_ATOMS && y+baseLocalAtom+j < NUM_ATOMS && r2 < CUTOFF_SQUARED) { unsigned int includeInteraction = (atom1 < NUM_ATOMS && y+baseLocalAtom+j < NUM_ATOMS && r2 < CUTOFF_SQUARED && (j+baseLocalAtom != tgx) && (params1.x < rScaledRadiusJ));
#else #else
if (atom1 < NUM_ATOMS && y+baseLocalAtom+j < NUM_ATOMS) { unsigned int includeInteraction = (atom1 < NUM_ATOMS && y+baseLocalAtom+j < NUM_ATOMS && (j+baseLocalAtom != tgx) && (params1.x < rScaledRadiusJ));
#endif #endif
float invR = RSQRT(r2); float l_ij = RECIP(max(params1.x, fabs(r-params2.y)));
float r = RECIP(invR); float u_ij = RECIP(rScaledRadiusJ);
float2 params2 = (float2) (localData[baseLocalAtom+j].radius, localData[baseLocalAtom+j].scaledRadius); float l_ij2 = l_ij*l_ij;
float rScaledRadiusJ = r+params2.y; float u_ij2 = u_ij*u_ij;
if ((j+baseLocalAtom != tgx) && (params1.x < rScaledRadiusJ)) { float ratio = LOG(u_ij * RECIP(l_ij));
float l_ij = RECIP(max(params1.x, fabs(r-params2.y))); bornSum += select(0.0f, l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) +
float u_ij = RECIP(rScaledRadiusJ); (0.25f*params2.y*params2.y*invR)*(l_ij2-u_ij2), includeInteraction);
float l_ij2 = l_ij*l_ij; bornSum += select(0.0f, 2.0f*(RECIP(params1.x)-l_ij), includeInteraction && params1.x < params2.x-r);
float u_ij2 = u_ij*u_ij;
float ratio = LOG(u_ij * RECIP(l_ij));
bornSum += l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) +
(0.25f*params2.y*params2.y*invR)*(l_ij2-u_ij2);
if (params1.x < params2.x-r)
bornSum += 2.0f*(RECIP(params1.x)-l_ij);
}
}
} }
// Sum the forces and write results. // Sum the forces and write results.
...@@ -130,38 +126,36 @@ void computeBornSum(__global float* global_bornSum, __global float4* posq, __glo ...@@ -130,38 +126,36 @@ void computeBornSum(__global float* global_bornSum, __global float4* posq, __glo
#endif #endif
float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z; float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
if (atom1 < NUM_ATOMS && y+baseLocalAtom+tj < NUM_ATOMS && r2 < CUTOFF_SQUARED) { unsigned int includeInteraction = (atom1 < NUM_ATOMS && y+baseLocalAtom+tj < NUM_ATOMS && r2 < CUTOFF_SQUARED);
#else #else
if (atom1 < NUM_ATOMS && y+baseLocalAtom+tj < NUM_ATOMS) { unsigned int includeInteraction = (atom1 < NUM_ATOMS && y+baseLocalAtom+tj < NUM_ATOMS);
#endif #endif
float invR = RSQRT(r2); float invR = RSQRT(r2);
float r = RECIP(invR); float r = RECIP(invR);
float2 params2 = (float2) (localData[baseLocalAtom+tj].radius, localData[baseLocalAtom+tj].scaledRadius); float2 params2 = (float2) (localData[baseLocalAtom+tj].radius, localData[baseLocalAtom+tj].scaledRadius);
float rScaledRadiusJ = r+params2.y; float rScaledRadiusJ = r+params2.y;
if (params1.x < rScaledRadiusJ) { {
float l_ij = RECIP(max(params1.x, fabs(r-params2.y))); float l_ij = RECIP(max(params1.x, fabs(r-params2.y)));
float u_ij = RECIP(rScaledRadiusJ); float u_ij = RECIP(rScaledRadiusJ);
float l_ij2 = l_ij*l_ij; float l_ij2 = l_ij*l_ij;
float u_ij2 = u_ij*u_ij; float u_ij2 = u_ij*u_ij;
float ratio = LOG(u_ij * RECIP(l_ij)); float ratio = LOG(u_ij * RECIP(l_ij));
bornSum += l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) + unsigned int includeTerm = (includeInteraction && params1.x < rScaledRadiusJ);
(0.25f*params2.y*params2.y*invR)*(l_ij2-u_ij2); bornSum += select(0.0f, l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) +
if (params1.x < params2.x-r) (0.25f*params2.y*params2.y*invR)*(l_ij2-u_ij2), includeTerm);
bornSum += 2.0f*(RECIP(params1.x)-l_ij); bornSum += select(0.0f, 2.0f*(RECIP(params1.x)-l_ij), includeTerm && params1.x < params2.x-r);
} }
float rScaledRadiusI = r+params1.y; float rScaledRadiusI = r+params1.y;
if (params2.x < rScaledRadiusI) { {
float l_ij = RECIP(max(params2.x, fabs(r-params1.y))); float l_ij = RECIP(max(params2.x, fabs(r-params1.y)));
float u_ij = RECIP(rScaledRadiusI); float u_ij = RECIP(rScaledRadiusI);
float l_ij2 = l_ij*l_ij; float l_ij2 = l_ij*l_ij;
float u_ij2 = u_ij*u_ij; float u_ij2 = u_ij*u_ij;
float ratio = LOG(u_ij * RECIP(l_ij)); float ratio = LOG(u_ij * RECIP(l_ij));
float term = l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) + float term = l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) +
(0.25f*params1.y*params1.y*invR)*(l_ij2-u_ij2); (0.25f*params1.y*params1.y*invR)*(l_ij2-u_ij2);
if (params2.x < params1.x-r) term += select(0.0f, 2.0f*(RECIP(params2.x)-l_ij), params2.x < params1.x-r);
term += 2.0f*(RECIP(params2.x)-l_ij); localData[baseLocalAtom+tj+forceBufferOffset].bornSum += select(0.0f, term, includeInteraction && params2.x < rScaledRadiusI);
localData[baseLocalAtom+tj+forceBufferOffset].bornSum += term;
}
} }
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
tj = (tj+1)%(TILE_SIZE/2); tj = (tj+1)%(TILE_SIZE/2);
...@@ -234,38 +228,35 @@ void computeGBSAForce1(__global float4* forceBuffers, __global float* energyBuff ...@@ -234,38 +228,35 @@ void computeGBSAForce1(__global float4* forceBuffers, __global float* energyBuff
unsigned int xi = x/TILE_SIZE; unsigned int xi = x/TILE_SIZE;
unsigned int tile = xi+xi*PADDED_NUM_ATOMS/TILE_SIZE-xi*(xi+1)/2; unsigned int tile = xi+xi*PADDED_NUM_ATOMS/TILE_SIZE-xi*(xi+1)/2;
for (unsigned int j = 0; j < TILE_SIZE/2; j++) { for (unsigned int j = 0; j < TILE_SIZE/2; j++) {
if (atom1 < NUM_ATOMS && y+baseLocalAtom+j < NUM_ATOMS) { unsigned int includeInteraction = (atom1 < NUM_ATOMS && y+baseLocalAtom+j < NUM_ATOMS);
float4 posq2 = (float4) (localData[baseLocalAtom+j].x, localData[baseLocalAtom+j].y, localData[baseLocalAtom+j].z, localData[baseLocalAtom+j].q); float4 posq2 = (float4) (localData[baseLocalAtom+j].x, localData[baseLocalAtom+j].y, localData[baseLocalAtom+j].z, localData[baseLocalAtom+j].q);
float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f); float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x; delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y; delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z; delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif #endif
float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z; float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
float invR = RSQRT(r2); float invR = RSQRT(r2);
float r = RECIP(invR); float r = RECIP(invR);
float bornRadius2 = localData[baseLocalAtom+j].bornRadius; float bornRadius2 = localData[baseLocalAtom+j].bornRadius;
float alpha2_ij = bornRadius1*bornRadius2; float alpha2_ij = bornRadius1*bornRadius2;
float D_ij = r2*RECIP(4.0f*alpha2_ij); float D_ij = r2*RECIP(4.0f*alpha2_ij);
float expTerm = EXP(-D_ij); float expTerm = EXP(-D_ij);
float denominator2 = r2 + alpha2_ij*expTerm; float denominator2 = r2 + alpha2_ij*expTerm;
float denominator = SQRT(denominator2); float denominator = SQRT(denominator2);
float tempEnergy = (PREFACTOR*posq1.w*posq2.w)*RECIP(denominator); float tempEnergy = (PREFACTOR*posq1.w*posq2.w)*RECIP(denominator);
float Gpol = tempEnergy*RECIP(denominator2); float Gpol = tempEnergy*RECIP(denominator2);
float dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f+D_ij); float dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f+D_ij);
force.w += dGpol_dalpha2_ij*bornRadius2; force.w += select(0.0f, dGpol_dalpha2_ij*bornRadius2, includeInteraction);
float dEdR = Gpol*(1.0f - 0.25f*expTerm); float dEdR = Gpol*(1.0f - 0.25f*expTerm);
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
if (r2 > CUTOFF_SQUARED) { dEdR = select(dEdR, 0.0f, r2 > CUTOFF_SQUARED);
dEdR = 0.0f; tempEnergy = select(tempEnergy, 0.0f, r2 > CUTOFF_SQUARED);
tempEnergy = 0.0f;
}
#endif #endif
energy += 0.5f*tempEnergy; energy += select(0.0f, 0.5f*tempEnergy, includeInteraction);
delta.xyz *= dEdR; delta.xyz *= select(0.0f, dEdR, includeInteraction);
force.xyz -= delta.xyz; force.xyz -= delta.xyz;
}
} }
// Sum the forces and write results. // Sum the forces and write results.
...@@ -308,42 +299,39 @@ void computeGBSAForce1(__global float4* forceBuffers, __global float* energyBuff ...@@ -308,42 +299,39 @@ void computeGBSAForce1(__global float4* forceBuffers, __global float* energyBuff
unsigned int tile = xi+yi*PADDED_NUM_ATOMS/TILE_SIZE-yi*(yi+1)/2; unsigned int tile = xi+yi*PADDED_NUM_ATOMS/TILE_SIZE-yi*(yi+1)/2;
unsigned int tj = tgx%(TILE_SIZE/2); unsigned int tj = tgx%(TILE_SIZE/2);
for (unsigned int j = 0; j < TILE_SIZE/2; j++) { for (unsigned int j = 0; j < TILE_SIZE/2; j++) {
if (atom1 < NUM_ATOMS && y+baseLocalAtom+tj < NUM_ATOMS) { unsigned int includeInteraction = (atom1 < NUM_ATOMS && y+baseLocalAtom+tj < NUM_ATOMS);
float4 posq2 = (float4) (localData[baseLocalAtom+tj].x, localData[baseLocalAtom+tj].y, localData[baseLocalAtom+tj].z, localData[baseLocalAtom+tj].q); float4 posq2 = (float4) (localData[baseLocalAtom+tj].x, localData[baseLocalAtom+tj].y, localData[baseLocalAtom+tj].z, localData[baseLocalAtom+tj].q);
float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f); float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x; delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y; delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z; delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif #endif
float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z; float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
float invR = RSQRT(r2); float invR = RSQRT(r2);
float r = RECIP(invR); float r = RECIP(invR);
float bornRadius2 = localData[baseLocalAtom+tj].bornRadius; float bornRadius2 = localData[baseLocalAtom+tj].bornRadius;
float alpha2_ij = bornRadius1*bornRadius2; float alpha2_ij = bornRadius1*bornRadius2;
float D_ij = r2*RECIP(4.0f*alpha2_ij); float D_ij = r2*RECIP(4.0f*alpha2_ij);
float expTerm = EXP(-D_ij); float expTerm = EXP(-D_ij);
float denominator2 = r2 + alpha2_ij*expTerm; float denominator2 = r2 + alpha2_ij*expTerm;
float denominator = SQRT(denominator2); float denominator = SQRT(denominator2);
float tempEnergy = (PREFACTOR*posq1.w*posq2.w)*RECIP(denominator); float tempEnergy = (PREFACTOR*posq1.w*posq2.w)*RECIP(denominator);
float Gpol = tempEnergy*RECIP(denominator2); float Gpol = tempEnergy*RECIP(denominator2);
float dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f+D_ij); float dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f+D_ij);
force.w += dGpol_dalpha2_ij*bornRadius2; force.w += select(0.0f, dGpol_dalpha2_ij*bornRadius2, includeInteraction);
float dEdR = Gpol*(1.0f - 0.25f*expTerm); float dEdR = Gpol*(1.0f - 0.25f*expTerm);
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
if (r2 > CUTOFF_SQUARED) { dEdR = select(dEdR, 0.0f, r2 > CUTOFF_SQUARED);
dEdR = 0.0f; tempEnergy = select(tempEnergy, 0.0f, r2 > CUTOFF_SQUARED);
tempEnergy = 0.0f;
}
#endif #endif
energy += tempEnergy; energy += select(0.0f, tempEnergy, includeInteraction);
delta.xyz *= dEdR; delta.xyz *= select(0.0f, dEdR, includeInteraction);
force.xyz -= delta.xyz; force.xyz -= delta.xyz;
localData[baseLocalAtom+tj+forceBufferOffset].fx += delta.x; localData[baseLocalAtom+tj+forceBufferOffset].fx += delta.x;
localData[baseLocalAtom+tj+forceBufferOffset].fy += delta.y; localData[baseLocalAtom+tj+forceBufferOffset].fy += delta.y;
localData[baseLocalAtom+tj+forceBufferOffset].fz += delta.z; localData[baseLocalAtom+tj+forceBufferOffset].fz += delta.z;
localData[baseLocalAtom+tj+forceBufferOffset].fw += dGpol_dalpha2_ij*bornRadius1; localData[baseLocalAtom+tj+forceBufferOffset].fw += select(0.0f, dGpol_dalpha2_ij*bornRadius1, includeInteraction);
}
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
tj = (tj+1)%(TILE_SIZE/2); tj = (tj+1)%(TILE_SIZE/2);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment