findInteractingBlocks.cl 5.96 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
// Number of consecutive atoms that make up one block.  The kernels in this file use it
// as the stride into posq (block b starts at atom b*BlockSize) and as the width of the
// group of work-items that cooperates on one tile.
// NOTE(review): OpenCL C 1.x expects program-scope variables to be declared in the
// __constant address space -- confirm that every target compiler accepts this form.
const int BlockSize = 32;

/**
 * Compute an axis-aligned bounding box for every block of BlockSize consecutive atoms.
 *
 * For block b the kernel writes the midpoint of the box to blockCenter[b] and the
 * half-extent of the box along each axis to blockBoundingBox[b].  Each work-item starts
 * at the block matching its global id and advances by the global size until the first
 * atom of its block would be at or past numAtoms.  The final block may hold fewer than
 * BlockSize atoms; only atoms below numAtoms are read.
 *
 * When USE_PERIODIC is defined, the first atom of the block is wrapped into
 * [0, periodicBoxSize) and every other atom is shifted to the periodic image closest
 * to that first atom before the extents are accumulated.
 */
__kernel void findBlockBounds(int numAtoms, float4 periodicBoxSize, __global float4* posq, __global float4* blockCenter, __global float4* blockBoundingBox) {
    for (int block = get_global_id(0); block*BlockSize < numAtoms; block += get_global_size(0)) {
        const int firstAtom = block*BlockSize;
        const int endAtom = min(firstAtom+BlockSize, numAtoms);

        // The first atom of the block seeds both corners of the box.

        float4 anchor = posq[firstAtom];
#ifdef USE_PERIODIC
        anchor.x -= floor(anchor.x/periodicBoxSize.x)*periodicBoxSize.x;
        anchor.y -= floor(anchor.y/periodicBoxSize.y)*periodicBoxSize.y;
        anchor.z -= floor(anchor.z/periodicBoxSize.z)*periodicBoxSize.z;
#endif
        float4 lower = anchor;
        float4 upper = anchor;

        // Grow the box over the remaining atoms of the block.

        for (int atom = firstAtom+1; atom < endAtom; atom++) {
            float4 p = posq[atom];
#ifdef USE_PERIODIC
            p.x -= floor((p.x-anchor.x)/periodicBoxSize.x+0.5f)*periodicBoxSize.x;
            p.y -= floor((p.y-anchor.y)/periodicBoxSize.y+0.5f)*periodicBoxSize.y;
            p.z -= floor((p.z-anchor.z)/periodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
            lower = min(lower, p);
            upper = max(upper, p);
        }

        // Store the box as half-extent plus center.

        blockBoundingBox[block] = 0.5f*(upper-lower);
        blockCenter[block] = 0.5f*(upper+lower);
    }
}

/**
 * Decide, for every tile, whether the bounding boxes of its two blocks are close enough
 * to interact.
 *
 * Each entry of tiles packs two block indices: the first in bits 17 and above, the
 * second in bits 2..16.  For tile t the kernel writes interactionFlag[t] = 0 when the
 * squared gap between the two boxes is greater than cutoffSquared, and 1 otherwise.
 * Work-items stride over the tiles by the global size.
 *
 * When USE_PERIODIC is defined, the center-to-center offset is first reduced to the
 * nearest periodic image.
 */
__kernel void findBlocksWithInteractions(int numTiles, float cutoffSquared, float4 periodicBoxSize, __global unsigned int* tiles, __global float4* blockCenter,
        __global float4* blockBoundingBox, __global unsigned int* interactionFlag) {
    for (int tile = get_global_id(0); tile < numTiles; tile += get_global_size(0)) {
        // Unpack the two block indices stored in this tile.

        const unsigned int packed = tiles[tile];
        const unsigned int blockA = (packed >> 17);
        const unsigned int blockB = ((packed >> 2) & 0x7fff);

        // Offset between the two box centers.

        float4 offset = blockCenter[blockA]-blockCenter[blockB];
#ifdef USE_PERIODIC
        offset.x -= floor(offset.x/periodicBoxSize.x+0.5f)*periodicBoxSize.x;
        offset.y -= floor(offset.y/periodicBoxSize.y+0.5f)*periodicBoxSize.y;
        offset.z -= floor(offset.z/periodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif

        // Per-axis gap between the boxes: the center offset minus both half-extents,
        // clamped at zero when the boxes overlap along that axis.

        const float4 halfA = blockBoundingBox[blockA];
        const float4 halfB = blockBoundingBox[blockB];
        const float gapX = max(0.0f, fabs(offset.x)-halfA.x-halfB.x);
        const float gapY = max(0.0f, fabs(offset.y)-halfA.y-halfB.y);
        const float gapZ = max(0.0f, fabs(offset.z)-halfA.z-halfB.z);

        if (gapX*gapX+gapY*gapY+gapZ*gapZ > cutoffSquared)
            interactionFlag[tile] = 0;
        else
            interactionFlag[tile] = 1;
    }
}

/**
 * Compare each atom in one block to the bounding box of another block, and set
 * flags for which ones are interacting.
 *
 * Work-items are grouped into runs of BlockSize consecutive global ids ("warps" below).
 * The tiles [0, interactionCount[0]) are divided into contiguous slices, one per warp,
 * and every work-item of a warp walks the same slice.  Within a warp, work-item `index`
 * handles atom `index` of the tile's y block.
 *
 * For every tile, one word is written to interactionFlags[pos]:
 *   - 0xFFFFFFFF if the tile is on the diagonal (x == y) or has its exclusion bit
 *     (bit 0 of the tile word) set;
 *   - otherwise a bit mask in which bit i is set when atom i of block y lies within the
 *     cutoff of block x's bounding box -- replaced by 0xFFFFFFFF when more than 12 bits
 *     are set.
 *
 * flags is local scratch memory indexed by get_local_id(0).
 * NOTE(review): presumably the host sizes it to at least one entry per work-item in the
 * work-group -- confirm against the caller.
 */
__kernel void findInteractionsWithinBlocks(float cutoffSquared, float4 periodicBoxSize, __global float4* posq, __global unsigned int* tiles, __global float4* blockCenter,
            __global float4* blockBoundingBox, __global unsigned int* interactionFlags, __global unsigned int* interactionCount, __local unsigned int* flags) {
    // Split the tile list evenly between warps; this warp owns tiles [pos, end).
    unsigned int totalWarps = get_global_size(0)/BlockSize;
    unsigned int warp = get_global_id(0)/BlockSize;
    unsigned int numTiles = interactionCount[0];
    unsigned int pos = warp*numTiles/totalWarps;
    unsigned int end = (warp+1)*numTiles/totalWarps;
    // Lane of this work-item within its warp (0..BlockSize-1).
    unsigned int index = get_local_id(0) & (BlockSize - 1);

    // lasty remembers which y block apos was loaded for, so that consecutive tiles
    // sharing the same y block do not reload the atom position.
    unsigned int lasty = 0xFFFFFFFF;
    float4 apos;
    while (pos < end)
    {
        // Extract cell coordinates from appropriate work unit
        unsigned int x = tiles[pos];
        unsigned int y = ((x >> 2) & 0x7fff);
        bool bExclusionFlag = (x & 0x1);
        x = (x >> 17);
        if (x == y || bExclusionFlag)
        {
            // Assume this block will be dense.

            if (index == 0)
                interactionFlags[pos] = 0xFFFFFFFF;
        }
        else
        {
            // Load the bounding box for x and the atom positions for y.

            float4 center = blockCenter[x];
            float4 boxSize = blockBoundingBox[x];
            if (y != lasty)
                apos = posq[(y*BlockSize)+index];

            // Find the distance of the atom from the bounding box.

            float4 delta = apos-center;
#ifdef USE_PERIODIC
            delta.x -= floor(delta.x/periodicBoxSize.x+0.5f)*periodicBoxSize.x;
            delta.y -= floor(delta.y/periodicBoxSize.y+0.5f)*periodicBoxSize.y;
            delta.z -= floor(delta.z/periodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
            // Per-axis distance outside the box; zero on axes where the atom is inside.
            delta = max((float4) 0.0f, fabs(delta)-boxSize);
            int thread = get_local_id(0);
            // Each lane contributes only its own bit, so the sums below cannot carry
            // and are equivalent to a bitwise OR across the warp.
            flags[thread] = (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z > cutoffSquared ? 0 : 1 << index);

            // Sum the flags.
            // NOTE(review): there is no barrier or fence between these reduction steps,
            // so each step reads values written by other work-items in the previous
            // step without explicit synchronization.  This appears to rely on the
            // BlockSize work-items of a warp executing in lockstep -- confirm for every
            // target device.

            if (index % 2 == 0)
                flags[thread] += flags[thread+1];
            if (index % 4 == 0)
                flags[thread] += flags[thread+2];
            if (index % 8 == 0)
                flags[thread] += flags[thread+4];
            if (index % 16 == 0)
                flags[thread] += flags[thread+8];
            if (index == 0)
            {
                // Lane 0 combines the two 16-lane halves into the full 32-bit mask.
                unsigned int allFlags = flags[thread] + flags[thread+16];

                // Count how many flags are set, and based on that decide whether to compute all interactions
                // or only a fraction of them.

                // Parallel bit count: successively add adjacent 1-, 2-, 4-, 8- and
                // 16-bit fields so that `bits` ends up holding the number of set bits.
                unsigned int bits = (allFlags&0x55555555) + ((allFlags>>1)&0x55555555);
                bits = (bits&0x33333333) + ((bits>>2)&0x33333333);
                bits = (bits&0x0F0F0F0F) + ((bits>>4)&0x0F0F0F0F);
                bits = (bits&0x00FF00FF) + ((bits>>8)&0x00FF00FF);
                bits = (bits&0x0000FFFF) + ((bits>>16)&0x0000FFFF);
                // More than 12 interacting atoms: mark the whole tile as interacting.
                interactionFlags[pos] = (bits > 12 ? 0xFFFFFFFF : allFlags);
            }
            // Recorded only on this path, so apos stays valid for the next tile that
            // reuses the same y block.
            lasty = y;
        }
        pos++;
    }
}