// findInteractingBlocks_cpu.cl
// Enables atom_add() on 32-bit integers in global memory; storeInteractionData() uses it to
// reserve space in the output tile list.
#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable
// NOTE(review): presumably requested for stores through pointers to types narrower than
// 32 bits (the tile lists are built from ushort components) -- confirm it is still required.
#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable
// Number of consecutive atoms that make up one block (and one side of a tile).
#define TILE_SIZE 32
// GROUP_SIZE and BUFFER_GROUPS are only used to derive BUFFER_SIZE in this file.
#define GROUP_SIZE 64
#define BUFFER_GROUPS 4
// Capacity, in tiles, of the local buffer that collects candidate tiles before they are
// filtered and flushed to global memory.  The expansion is parenthesized so that the macro
// behaves as a single value inside larger expressions (e.g. n / BUFFER_SIZE).
#define BUFFER_SIZE (BUFFER_GROUPS*GROUP_SIZE)

/**
 * Compute an axis-aligned bounding box for every block of TILE_SIZE consecutive atoms.
 *
 * For each block, the center of the box is written to blockCenter and its half-widths are
 * written to blockBoundingBox.  When USE_PERIODIC is defined, the first atom of the block is
 * wrapped into the periodic box and every other atom is shifted to the periodic image closest
 * to that first atom before the bounds are accumulated.  The work-item with global id 0 also
 * resets interactionCount[0] to zero.
 */
__kernel void findBlockBounds(int numAtoms, float4 periodicBoxSize, float4 invPeriodicBoxSize, __global float4* posq, __global float4* blockCenter, __global float4* blockBoundingBox, __global unsigned int* interactionCount) {
    // Each work-item starts at the block matching its global id and advances by the global size.

    for (int block = get_global_id(0); block*TILE_SIZE < numAtoms; block += get_global_size(0)) {
        const int firstAtom = block*TILE_SIZE;
        const int stopAtom = min(firstAtom+TILE_SIZE, numAtoms);

        // The first atom of the block seeds both corners of the box.

        float4 seed = posq[firstAtom];
#ifdef USE_PERIODIC
        seed.x -= floor(seed.x*invPeriodicBoxSize.x)*periodicBoxSize.x;
        seed.y -= floor(seed.y*invPeriodicBoxSize.y)*periodicBoxSize.y;
        seed.z -= floor(seed.z*invPeriodicBoxSize.z)*periodicBoxSize.z;
#endif
        float4 lower = seed;
        float4 upper = seed;

        // Grow the box to enclose the remaining atoms of the block.

        for (int atom = firstAtom+1; atom < stopAtom; atom++) {
            float4 current = posq[atom];
#ifdef USE_PERIODIC
            current.x -= floor((current.x-seed.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
            current.y -= floor((current.y-seed.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
            current.z -= floor((current.z-seed.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
            lower = min(lower, current);
            upper = max(upper, current);
        }
        blockCenter[block] = 0.5f*(upper+lower);
        blockBoundingBox[block] = 0.5f*(upper-lower);
    }

    // A single work-item clears the tile counter.

    if (get_global_id(0) == 0)
        interactionCount[0] = 0;
}

/**
 * This is called by findBlocksWithInteractions().  It filters the buffered list of candidate
 * tiles, compacts it, and appends the surviving tiles to global memory.
 *
 * For every off-diagonal tile (x, y), each of the TILE_SIZE atoms of block y is tested against
 * the bounding box of block x.  Bit i of the tile's flags is set when atom i lies within the
 * cutoff of that box.  A tile whose flags are 0 is dropped by overwriting it with the last
 * buffered entry.  Diagonal tiles (x == y) are always kept without testing and are given flags
 * with every bit set.
 *
 * @param buffer              candidate tiles as (x, y) block indices; reordered in place
 * @param numValid            number of entries in buffer
 * @param flagsBuffer         scratch space that receives the flags of each kept tile
 * @param temp                scratch space for the TILE_SIZE atom positions of one block
 * @param interactionCount    global tile counter; atomically increased by the number of kept tiles
 * @param interactingTiles    output list of kept tiles
 * @param interactionFlags    output list of flags, parallel to interactingTiles
 * @param cutoffSquared       square of the cutoff distance
 * @param periodicBoxSize     periodic box size (only used when USE_PERIODIC is defined)
 * @param invPeriodicBoxSize  multiplied with distances to count box lengths (only used when
 *                            USE_PERIODIC is defined); presumably 1/periodicBoxSize -- confirm
 * @param posq                atom positions.  NOTE(review): a full TILE_SIZE atoms are read for
 *                            every block, so this is presumably padded to a multiple of
 *                            TILE_SIZE -- confirm with the host code.
 * @param blockCenter         center of each block's bounding box
 * @param blockBoundingBox    half-widths of each block's bounding box
 * @param maxTiles            capacity of interactingTiles and interactionFlags.  If the kept
 *                            tiles would not fit, nothing is written but interactionCount is
 *                            still increased.
 */
void storeInteractionData(__local ushort2* buffer, int numValid, __local unsigned int* flagsBuffer, __local float4* temp,
            __global unsigned int* interactionCount, __global ushort2* interactingTiles, __global unsigned int* interactionFlags, float cutoffSquared, float4 periodicBoxSize,
            float4 invPeriodicBoxSize, __global float4* posq, __global float4* blockCenter, __global float4* blockBoundingBox, unsigned int maxTiles) {
    // Filter the list of tiles by comparing the distance from each atom to the other bounding box.

    int lasty = -1;
    for (int tile = 0; tile < numValid; ) {
        int x = buffer[tile].x;
        int y = buffer[tile].y;
        if (x == y) {
            // Diagonal tiles are always kept.  Give them a well defined flags value (every atom
            // marked) so that uninitialized local memory is never copied to the output.

            flagsBuffer[tile] = 0xFFFFFFFFu;
            tile++;
            continue;
        }

        // Load the atom positions and the bounding box of the other block.  The positions are
        // reused when consecutive tiles share the same y.

        float4 center = blockCenter[x];
        float4 boxSize = blockBoundingBox[x];
        if (y != lasty)
            for (int atom = 0; atom < TILE_SIZE; atom++)
                temp[atom] = posq[y*TILE_SIZE+atom];
        lasty = y;

        // Find the distance of each atom from the bounding box.

        unsigned int flags = 0;
        for (int atom = 0; atom < TILE_SIZE; atom++) {
            float4 delta = temp[atom]-center;
#ifdef USE_PERIODIC
            delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
            delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
            delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
            delta = max((float4) 0.0f, fabs(delta)-boxSize);
            if (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < cutoffSquared)
                flags |= 1u << atom; // Unsigned literal: bit 31 is a valid flag.
        }
        if (flags == 0) {
            // This tile contains no interactions.  Replace it with the last entry and examine
            // that entry next (tile is deliberately not advanced).

            numValid--;
            buffer[tile] = buffer[numValid];
        }
        else {
            flagsBuffer[tile] = flags;
            tile++;
        }
    }

    // Store it to global memory.

    if (numValid == 0)
        return; // Nothing to append, so the counter does not need to be touched.
    unsigned int count = (unsigned int) numValid;
    unsigned int baseIndex = atom_add(interactionCount, count);

    // Written so that it cannot wrap around: only copy when the whole batch fits.

    if (baseIndex <= maxTiles && count <= maxTiles-baseIndex)
        for (unsigned int i = 0; i < count; i++) {
            interactingTiles[baseIndex+i] = buffer[i];
            interactionFlags[baseIndex+i] = flagsBuffer[i];
        }
}

/**
 * Compare the bounding boxes for each pair of blocks.  If they are sufficiently far apart,
 * mark them as non-interacting.
 *
 * The NUM_BLOCKS*(NUM_BLOCKS+1)/2 tiles (x, y) with y <= x < NUM_BLOCKS are numbered row by row
 * and divided into contiguous ranges, one per work-group.  Every tile whose two bounding boxes
 * are closer than the cutoff is added to a local buffer.  Whenever the buffer is full, and once
 * more at the end, storeInteractionData() filters the buffer and appends it to the output.
 * NUM_BLOCKS must be defined when this source is compiled; it is not defined in this file.
 *
 * NOTE(review): the work is divided by get_group_id() only and the local buffers are shared by
 * the whole work-group, so this appears to be intended for one work-item per work-group --
 * confirm with the host launch code.
 *
 * @param cutoffSquared       square of the cutoff distance
 * @param periodicBoxSize     periodic box size (only used when USE_PERIODIC is defined)
 * @param invPeriodicBoxSize  multiplied with distances to count box lengths (only used when
 *                            USE_PERIODIC is defined); presumably 1/periodicBoxSize -- confirm
 * @param blockCenter         center of each block's bounding box
 * @param blockBoundingBox    half-widths of each block's bounding box
 * @param interactionCount    global tile counter, updated by storeInteractionData()
 * @param interactingTiles    output list of interacting tiles
 * @param interactionFlags    output list of per-tile atom flags
 * @param posq                atom positions, passed through to storeInteractionData()
 * @param maxTiles            capacity of interactingTiles and interactionFlags
 */
__kernel void findBlocksWithInteractions(float cutoffSquared, float4 periodicBoxSize, float4 invPeriodicBoxSize, __global float4* blockCenter,
        __global float4* blockBoundingBox, __global unsigned int* interactionCount, __global ushort2* interactingTiles,
        __global unsigned int* interactionFlags, __global float4* posq, unsigned int maxTiles) {
    __local ushort2 buffer[BUFFER_SIZE];
    __local unsigned int flagsBuffer[BUFFER_SIZE];
    __local float4 temp[TILE_SIZE];
    int valuesInBuffer = 0;
    const int numTiles = (NUM_BLOCKS*(NUM_BLOCKS+1))/2;
    const int start = (int) (get_group_id(0)*numTiles/get_num_groups(0));
    const int end = (int) ((get_group_id(0)+1)*numTiles/get_num_groups(0));
    for (int index = start; index < end; index++) {
        // Identify the pair of blocks to compare.  The row is estimated in single precision, so
        // it can be off in either direction.  Signed arithmetic keeps a row that is too large
        // detectable (x < y, possibly negative) as well as one that is too small
        // (x >= NUM_BLOCKS); y is moved toward the correct row until y <= x < NUM_BLOCKS.

        int y = (int) floor(NUM_BLOCKS+0.5f-sqrt((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*index));
        int x;
        for (;;) {
            x = index-y*NUM_BLOCKS+y*(y+1)/2;
            if (x < y)
                y--;
            else if (x >= NUM_BLOCKS)
                y++;
            else
                break;
        }

        // Find the distance between the bounding boxes of the two cells.

        float4 delta = blockCenter[x]-blockCenter[y];
#ifdef USE_PERIODIC
        delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
        delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
        delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
        float4 boxSizea = blockBoundingBox[x];
        float4 boxSizeb = blockBoundingBox[y];
        delta.x = max(0.0f, fabs(delta.x)-boxSizea.x-boxSizeb.x);
        delta.y = max(0.0f, fabs(delta.y)-boxSizea.y-boxSizeb.y);
        delta.z = max(0.0f, fabs(delta.z)-boxSizea.z-boxSizeb.z);
        if (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < cutoffSquared) {
            // Add this tile to the buffer, flushing it as soon as it is full.

            buffer[valuesInBuffer++] = (ushort2) ((ushort) x, (ushort) y);
            if (valuesInBuffer == BUFFER_SIZE) {
                storeInteractionData(buffer, valuesInBuffer, flagsBuffer, temp, interactionCount, interactingTiles, interactionFlags, cutoffSquared, periodicBoxSize, invPeriodicBoxSize, posq, blockCenter, blockBoundingBox, maxTiles);
                valuesInBuffer = 0;
            }
        }
    }

    // Flush whatever is left in the buffer.

    storeInteractionData(buffer, valuesInBuffer, flagsBuffer, temp, interactionCount, interactingTiles, interactionFlags, cutoffSquared, periodicBoxSize, invPeriodicBoxSize, posq, blockCenter, blockBoundingBox, maxTiles);
}