Commit 91f3379b authored by Peter Eastman
Browse files

Minor optimizations to GBSA

parent d2a5b3bb
{ {
float invRSquared = RECIP(r2); float invRSquaredOver4 = 0.25f*invR*invR;
float rScaledRadiusJ = r+obcParams2.y; float rScaledRadiusJ = r+obcParams2.y;
float rScaledRadiusI = r+obcParams1.y; float rScaledRadiusI = r+obcParams1.y;
float l_ijJ = RECIP(max(obcParams1.x, fabs(r-obcParams2.y))); float l_ijJ = RECIP(max(obcParams1.x, fabs(r-obcParams2.y)));
...@@ -14,12 +14,8 @@ ...@@ -14,12 +14,8 @@
float t1I = LOG(u_ijI*RECIP(l_ijI)); float t1I = LOG(u_ijI*RECIP(l_ijI));
float t2J = (l_ij2J-u_ij2J); float t2J = (l_ij2J-u_ij2J);
float t2I = (l_ij2I-u_ij2I); float t2I = (l_ij2I-u_ij2I);
float t3J = t2J*invR; float term1 = (0.5f*(0.25f+obcParams2.y*obcParams2.y*invRSquaredOver4)*t2J + t1J*invRSquaredOver4)*invR;
float t3I = t2I*invR; float term2 = (0.5f*(0.25f+obcParams1.y*obcParams1.y*invRSquaredOver4)*t2I + t1I*invRSquaredOver4)*invR;
t1J *= invR;
t1I *= invR;
float term1 = 0.125f*(1.0f+obcParams2.y*obcParams2.y*invRSquared)*t3J + 0.25f*t1J*invRSquared;
float term2 = 0.125f*(1.0f+obcParams1.y*obcParams1.y*invRSquared)*t3I + 0.25f*t1I*invRSquared;
float tempdEdR = select(0.0f, bornForce1*term1, obcParams1.x < rScaledRadiusJ); float tempdEdR = select(0.0f, bornForce1*term1, obcParams1.x < rScaledRadiusJ);
tempdEdR += select(0.0f, bornForce2*term2, obcParams2.x < rScaledRadiusJ); tempdEdR += select(0.0f, bornForce2*term2, obcParams2.x < rScaledRadiusJ);
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
......
...@@ -7,12 +7,9 @@ ...@@ -7,12 +7,9 @@
typedef struct { typedef struct {
float x, y, z; float x, y, z;
float q; float q;
float fx, fy, fz, fw;
float radius, scaledRadius; float radius, scaledRadius;
float bornSum; float bornSum;
float bornRadius; } AtomData1;
float bornForce;
} AtomData;
/** /**
* Compute the Born sum. * Compute the Born sum.
...@@ -24,7 +21,7 @@ __kernel void computeBornSum( ...@@ -24,7 +21,7 @@ __kernel void computeBornSum(
__global float* global_bornSum, __global float* global_bornSum,
#endif #endif
__global float4* posq, __global float2* global_params, __global float4* posq, __global float2* global_params,
__local AtomData* localData, __local float* tempBuffer, __local AtomData1* localData, __local float* tempBuffer,
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
__global ushort2* tiles, __global unsigned int* interactionCount, float4 periodicBoxSize, float4 invPeriodicBoxSize, unsigned int maxTiles, __global unsigned int* interactionFlags) { __global ushort2* tiles, __global unsigned int* interactionCount, float4 periodicBoxSize, float4 invPeriodicBoxSize, unsigned int maxTiles, __global unsigned int* interactionFlags) {
#else #else
...@@ -104,8 +101,8 @@ __kernel void computeBornSum( ...@@ -104,8 +101,8 @@ __kernel void computeBornSum(
float l_ij2 = l_ij*l_ij; float l_ij2 = l_ij*l_ij;
float u_ij2 = u_ij*u_ij; float u_ij2 = u_ij*u_ij;
float ratio = LOG(u_ij * RECIP(l_ij)); float ratio = LOG(u_ij * RECIP(l_ij));
bornSum += l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) + bornSum += l_ij - u_ij + (0.50f*invR*ratio) + 0.25f*(r*(u_ij2-l_ij2) +
(0.25f*params2.y*params2.y*invR)*(l_ij2-u_ij2); (params2.y*params2.y*invR)*(l_ij2-u_ij2));
if (params1.x < params2.x-r) if (params1.x < params2.x-r)
bornSum += 2.0f*(RECIP(params1.x)-l_ij); bornSum += 2.0f*(RECIP(params1.x)-l_ij);
} }
...@@ -161,8 +158,8 @@ __kernel void computeBornSum( ...@@ -161,8 +158,8 @@ __kernel void computeBornSum(
float l_ij2 = l_ij*l_ij; float l_ij2 = l_ij*l_ij;
float u_ij2 = u_ij*u_ij; float u_ij2 = u_ij*u_ij;
float ratio = LOG(u_ij * RECIP(l_ij)); float ratio = LOG(u_ij * RECIP(l_ij));
bornSum += l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) + bornSum += l_ij - u_ij + (0.50f*invR*ratio) + 0.25f*(r*(u_ij2-l_ij2) +
(0.25f*params2.y*params2.y*invR)*(l_ij2-u_ij2); (params2.y*params2.y*invR)*(l_ij2-u_ij2));
if (params1.x < params2.x-r) if (params1.x < params2.x-r)
bornSum += 2.0f*(RECIP(params1.x)-l_ij); bornSum += 2.0f*(RECIP(params1.x)-l_ij);
} }
...@@ -173,8 +170,8 @@ __kernel void computeBornSum( ...@@ -173,8 +170,8 @@ __kernel void computeBornSum(
float l_ij2 = l_ij*l_ij; float l_ij2 = l_ij*l_ij;
float u_ij2 = u_ij*u_ij; float u_ij2 = u_ij*u_ij;
float ratio = LOG(u_ij * RECIP(l_ij)); float ratio = LOG(u_ij * RECIP(l_ij));
float term = l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) + float term = l_ij - u_ij + (0.50f*invR*ratio) + 0.25f*(r*(u_ij2-l_ij2) +
(0.25f*params1.y*params1.y*invR)*(l_ij2-u_ij2); (params1.y*params1.y*invR)*(l_ij2-u_ij2));
if (params2.x < params1.x-r) if (params2.x < params1.x-r)
term += 2.0f*(RECIP(params2.x)-l_ij); term += 2.0f*(RECIP(params2.x)-l_ij);
tempBuffer[get_local_id(0)] = term; tempBuffer[get_local_id(0)] = term;
...@@ -220,8 +217,8 @@ __kernel void computeBornSum( ...@@ -220,8 +217,8 @@ __kernel void computeBornSum(
float l_ij2 = l_ij*l_ij; float l_ij2 = l_ij*l_ij;
float u_ij2 = u_ij*u_ij; float u_ij2 = u_ij*u_ij;
float ratio = LOG(u_ij * RECIP(l_ij)); float ratio = LOG(u_ij * RECIP(l_ij));
bornSum += l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) + bornSum += l_ij - u_ij + (0.50f*invR*ratio) + 0.25f*(r*(u_ij2-l_ij2) +
(0.25f*params2.y*params2.y*invR)*(l_ij2-u_ij2); (params2.y*params2.y*invR)*(l_ij2-u_ij2));
if (params1.x < params2.x-r) if (params1.x < params2.x-r)
bornSum += 2.0f*(RECIP(params1.x)-l_ij); bornSum += 2.0f*(RECIP(params1.x)-l_ij);
} }
...@@ -232,8 +229,8 @@ __kernel void computeBornSum( ...@@ -232,8 +229,8 @@ __kernel void computeBornSum(
float l_ij2 = l_ij*l_ij; float l_ij2 = l_ij*l_ij;
float u_ij2 = u_ij*u_ij; float u_ij2 = u_ij*u_ij;
float ratio = LOG(u_ij * RECIP(l_ij)); float ratio = LOG(u_ij * RECIP(l_ij));
float term = l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) + float term = l_ij - u_ij + (0.50f*invR*ratio) + 0.25f*(r*(u_ij2-l_ij2) +
(0.25f*params1.y*params1.y*invR)*(l_ij2-u_ij2); (params1.y*params1.y*invR)*(l_ij2-u_ij2));
if (params2.x < params1.x-r) if (params2.x < params1.x-r)
term += 2.0f*(RECIP(params2.x)-l_ij); term += 2.0f*(RECIP(params2.x)-l_ij);
localData[tbx+tj].bornSum += term; localData[tbx+tj].bornSum += term;
...@@ -313,6 +310,13 @@ __kernel void computeBornSum( ...@@ -313,6 +310,13 @@ __kernel void computeBornSum(
} while (pos < end); } while (pos < end);
} }
/**
 * Per-atom record used by the computeGBSAForce1 kernel, which receives a
 * __local pointer to this type as its localData argument.
 */
typedef struct {
float x, y, z; // NOTE(review): presumably the atom position (cf. the float4 posq buffer) -- confirm against the kernel body
float q; // NOTE(review): presumably the atom charge -- confirm
float fx, fy, fz, fw; // NOTE(review): presumably per-atom force accumulators; the role of fw is not visible here -- confirm
float bornRadius; // NOTE(review): presumably this atom's entry from global_bornRadii -- confirm
} AtomData2;
/** /**
* First part of computing the GBSA interaction. * First part of computing the GBSA interaction.
*/ */
...@@ -324,7 +328,7 @@ __kernel void computeGBSAForce1( ...@@ -324,7 +328,7 @@ __kernel void computeGBSAForce1(
__global float4* forceBuffers, __global float* global_bornForce, __global float4* forceBuffers, __global float* global_bornForce,
#endif #endif
__global float* energyBuffer, __global float4* posq, __global float* global_bornRadii, __global float* energyBuffer, __global float4* posq, __global float* global_bornRadii,
__local AtomData* localData, __local float4* tempBuffer, __local AtomData2* localData, __local float4* tempBuffer,
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
__global ushort2* tiles, __global unsigned int* interactionCount, float4 periodicBoxSize, float4 invPeriodicBoxSize, unsigned int maxTiles, __global unsigned int* interactionFlags) { __global ushort2* tiles, __global unsigned int* interactionCount, float4 periodicBoxSize, float4 invPeriodicBoxSize, unsigned int maxTiles, __global unsigned int* interactionFlags) {
#else #else
......
...@@ -159,7 +159,7 @@ __kernel void computeNonbonded( ...@@ -159,7 +159,7 @@ __kernel void computeNonbonded(
else { else {
// Compute only a subset of the interactions in this tile. // Compute only a subset of the interactions in this tile.
for (unsigned int j = 0; j < TILE_SIZE; j++) { for (j = 0; j < TILE_SIZE; j++) {
if ((flags&(1<<j)) != 0) { if ((flags&(1<<j)) != 0) {
bool isExcluded = false; bool isExcluded = false;
int atom2 = tbx+j; int atom2 = tbx+j;
...@@ -230,7 +230,7 @@ __kernel void computeNonbonded( ...@@ -230,7 +230,7 @@ __kernel void computeNonbonded(
excl = (excl >> tgx) | (excl << (TILE_SIZE - tgx)); excl = (excl >> tgx) | (excl << (TILE_SIZE - tgx));
#endif #endif
unsigned int tj = tgx; unsigned int tj = tgx;
for (unsigned int j = 0; j < TILE_SIZE; j++) { for (j = 0; j < TILE_SIZE; j++) {
#ifdef USE_EXCLUSIONS #ifdef USE_EXCLUSIONS
bool isExcluded = !(excl & 0x1); bool isExcluded = !(excl & 0x1);
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment