/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2013 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "openmm/OpenMMException.h"
#include "OpenCLNonbondedUtilities.h"
#include "OpenCLArray.h"
#include "OpenCLKernelSources.h"
#include "OpenCLExpressionUtilities.h"
#include "OpenCLSort.h"
#include <algorithm>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

using namespace OpenMM;
using namespace std;

/**
 * Sort trait passed to the OpenCLSort that orders atom blocks.  Device-side types are
 * named "real" (key) and "real2" (data); the host-side element sizes reported here
 * follow the precision flag given to the constructor.
 */
class OpenCLNonbondedUtilities::BlockSortTrait : public OpenCLSort::SortTrait {
public:
    BlockSortTrait(bool useDouble) : useDouble(useDouble) {
    }
    /** Size in bytes of one data element (a two component vector). */
    int getDataSize() const {
        if (useDouble)
            return sizeof(mm_double2);
        return sizeof(mm_float2);
    }
    /** Size in bytes of one sort key (a scalar). */
    int getKeySize() const {
        if (useDouble)
            return sizeof(cl_double);
        return sizeof(cl_float);
    }
    const char* getDataType() const {
        return "real2";
    }
    const char* getKeyType() const {
        return "real";
    }
    const char* getMinKey() const {
        return "-MAXFLOAT";
    }
    const char* getMaxKey() const {
        return "MAXFLOAT";
    }
    const char* getMaxValue() const {
        return "(real2) (MAXFLOAT, MAXFLOAT)";
    }
    /** The key is the x component of each value. */
    const char* getSortKey() const {
        return "value.x";
    }
private:
    bool useDouble;
};

/**
 * Create the nonbonded utilities for a context.  All device arrays start out as NULL (they
 * are allocated in initialize()), and the thread block geometry and number of force buffers
 * are chosen from the device type, its SIMD width, and whether it supports 64 bit global atomics.
 *
 * @param context    the context these utilities belong to
 */
OpenCLNonbondedUtilities::OpenCLNonbondedUtilities(OpenCLContext& context) : context(context), cutoff(-1.0), useCutoff(false), anyExclusions(false), usePadding(true),
        numForceBuffers(0), exclusionIndices(NULL), exclusionRowIndices(NULL), exclusionTiles(NULL), exclusions(NULL), interactingTiles(NULL), interactingAtoms(NULL),
        interactionCount(NULL), blockCenter(NULL), blockBoundingBox(NULL), sortedBlocks(NULL), sortedBlockCenter(NULL), sortedBlockBoundingBox(NULL),
        oldPositions(NULL), rebuildNeighborList(NULL), blockSorter(NULL), nonbondedForceGroup(0) {
    // usePeriodic is read by initialize() and createInteractionKernel() even when
    // addInteraction() was never called, so give it a defined value.  It is assigned here
    // rather than in the initializer list so member declaration order does not matter.

    usePeriodic = false;

    // Decide how many thread blocks and force buffers to use.

    deviceIsCpu = (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
    if (deviceIsCpu) {
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = 1;
        numForceBuffers = numForceThreadBlocks;
    }
    else if (context.getSIMDWidth() == 32) {
        if (context.getSupports64BitGlobalAtomics()) {
            numForceThreadBlocks = 4*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
            forceThreadBlockSize = 256;
            // Even though using longForceBuffer, still need a single forceBuffer for the reduceForces kernel to convert the long results into float4 which will be used by later kernels.
            numForceBuffers = 1;
        }
        else {
            numForceThreadBlocks = 3*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
            forceThreadBlockSize = 256;
            numForceBuffers = numForceThreadBlocks*forceThreadBlockSize/OpenCLContext::TileSize;
        }
    }
    else {
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = (context.getSIMDWidth() >= 32 ? OpenCLContext::ThreadBlockSize : 32);
        if (context.getSupports64BitGlobalAtomics()) {
            // Even though using longForceBuffer, still need a single forceBuffer for the reduceForces kernel to convert the long results into float4 which will be used by later kernels.
            numForceBuffers = 1;
        }
        else {
            numForceBuffers = numForceThreadBlocks*forceThreadBlockSize/OpenCLContext::TileSize;
        }
    }
}

/**
 * Release every device array and the block sorter owned by this object.
 */
OpenCLNonbondedUtilities::~OpenCLNonbondedUtilities() {
    // Members that were never allocated are still NULL, and deleting a null
    // pointer has no effect, so no explicit checks are needed.

    delete exclusionIndices;
    delete exclusionRowIndices;
    delete exclusionTiles;
    delete exclusions;
    delete interactingTiles;
    delete interactingAtoms;
    delete interactionCount;
    delete blockCenter;
    delete blockBoundingBox;
    delete sortedBlocks;
    delete sortedBlockCenter;
    delete sortedBlockBoundingBox;
    delete oldPositions;
    delete rebuildNeighborList;
    delete blockSorter;
}

void OpenCLNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup) {
129
130
131
132
133
134
135
    if (cutoff != -1.0) {
        if (usesCutoff != useCutoff)
            throw OpenMMException("All Forces must agree on whether to use a cutoff");
        if (usesPeriodic != usePeriodic)
            throw OpenMMException("All Forces must agree on whether to use periodic boundary conditions");
        if (cutoffDistance != cutoff)
            throw OpenMMException("All Forces must use the same cutoff distance");
136
137
        if (forceGroup != nonbondedForceGroup)
            throw OpenMMException("All nonbonded forces must be in the same force group");
138
    }
139
140
141
142
143
    if (usesExclusions)
        requestExclusions(exclusionList);
    useCutoff = usesCutoff;
    usePeriodic = usesPeriodic;
    cutoff = cutoffDistance;
Peter Eastman's avatar
Peter Eastman committed
144
145
    if (kernel.size() > 0)
        kernelSource += kernel+"\n";
146
    nonbondedForceGroup = forceGroup;
147
148
149
150
151
152
153
154
155
156
157
158
}

/**
 * Append a per-atom parameter to the list handed to createInteractionKernel() by initialize().
 */
void OpenCLNonbondedUtilities::addParameter(const ParameterInfo& parameter) {
    parameters.insert(parameters.end(), parameter);
}

/**
 * Append an extra kernel argument to the list handed to createInteractionKernel() by initialize().
 */
void OpenCLNonbondedUtilities::addArgument(const ParameterInfo& parameter) {
    arguments.insert(arguments.end(), parameter);
}

void OpenCLNonbondedUtilities::requestExclusions(const vector<vector<int> >& exclusionList) {
    if (anyExclusions) {
159
        bool sameExclusions = (exclusionList.size() == atomExclusions.size());
160
        for (int i = 0; i < (int) exclusionList.size() && sameExclusions; i++) {
161
162
            if (exclusionList[i].size() != atomExclusions[i].size())
                sameExclusions = false;
163
164
            set<int> expectedExclusions;
            expectedExclusions.insert(atomExclusions[i].begin(), atomExclusions[i].end());
165
            for (int j = 0; j < (int) exclusionList[i].size(); j++)
166
                if (expectedExclusions.find(exclusionList[i][j]) == expectedExclusions.end())
167
168
169
170
171
                    sameExclusions = false;
        }
        if (!sameExclusions)
            throw OpenMMException("All Forces must have identical exceptions");
    }
172
    else {
173
        atomExclusions = exclusionList;
174
175
        anyExclusions = true;
    }
176
177
}

178
179
180
181
/**
 * Ordering used to sort tiles: compare by the y component first and break ties with x.
 */
static bool compareUshort2(mm_ushort2 a, mm_ushort2 b) {
    if (a.y != b.y)
        return (a.y < b.y);
    return (a.x < b.x);
}

182
void OpenCLNonbondedUtilities::initialize(const System& system) {
183
184
185
186
    if (atomExclusions.size() == 0) {
        // No exclusions were specifically requested, so just mark every atom as not interacting with itself.
        
        atomExclusions.resize(context.getNumAtoms());
187
        for (int i = 0; i < (int) atomExclusions.size(); i++)
188
189
190
            atomExclusions[i].push_back(i);
    }

191
192
193
    // Create the list of tiles.

    int numAtomBlocks = context.getNumAtomBlocks();
194
    int numContexts = context.getPlatformData().contexts.size();
195
    setAtomBlockRange(context.getContextIndex()/(double) numContexts, (context.getContextIndex()+1)/(double) numContexts);
196

197
198
    // Build a list of tiles that contain exclusions.
    
199
200
201
202
203
204
205
206
207
    set<pair<int, int> > tilesWithExclusions;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            tilesWithExclusions.insert(make_pair(max(x, y), min(x, y)));
        }
    }
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
    vector<mm_ushort2> exclusionTilesVec;
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter)
        exclusionTilesVec.push_back(mm_ushort2((unsigned short) iter->first, (unsigned short) iter->second));
    sort(exclusionTilesVec.begin(), exclusionTilesVec.end(), compareUshort2);
    exclusionTiles = OpenCLArray::create<mm_ushort2>(context, exclusionTilesVec.size(), "exclusionTiles");
    exclusionTiles->upload(exclusionTilesVec);
    map<pair<int, int>, int> exclusionTileMap;
    for (int i = 0; i < (int) exclusionTilesVec.size(); i++) {
        mm_ushort2 tile = exclusionTilesVec[i];
        exclusionTileMap[make_pair(tile.x, tile.y)] = i;
    }
    vector<vector<int> > exclusionBlocksForBlock(numAtomBlocks);
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter) {
        exclusionBlocksForBlock[iter->first].push_back(iter->second);
        if (iter->first != iter->second)
            exclusionBlocksForBlock[iter->second].push_back(iter->first);
224
225
226
    }
    vector<cl_uint> exclusionRowIndicesVec(numAtomBlocks+1, 0);
    vector<cl_uint> exclusionIndicesVec;
227
228
229
    for (int i = 0; i < numAtomBlocks; i++) {
        exclusionIndicesVec.insert(exclusionIndicesVec.end(), exclusionBlocksForBlock[i].begin(), exclusionBlocksForBlock[i].end());
        exclusionRowIndicesVec[i+1] = exclusionIndicesVec.size();
230
    }
231
232
    exclusionIndices = OpenCLArray::create<cl_uint>(context, exclusionIndicesVec.size(), "exclusionIndices");
    exclusionRowIndices = OpenCLArray::create<cl_uint>(context, exclusionRowIndicesVec.size(), "exclusionRowIndices");
233
234
    exclusionIndices->upload(exclusionIndicesVec);
    exclusionRowIndices->upload(exclusionRowIndicesVec);
235
236
237

    // Record the exclusion data.

238
    exclusions = OpenCLArray::create<cl_uint>(context, tilesWithExclusions.size()*OpenCLContext::TileSize, "exclusions");
239
240
    cl_uint allFlags = (cl_uint) -1;
    vector<cl_uint> exclusionVec(exclusions->getSize(), allFlags);
241
242
243
244
245
246
247
248
249
250
    for (int i = 0; i < exclusions->getSize(); ++i)
        exclusionVec[i] = 0xFFFFFFFF;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x > y) {
251
252
                int index = exclusionTileMap[make_pair(x, y)]*OpenCLContext::TileSize;
                exclusionVec[index+offset1] &= allFlags-(1<<offset2);
253
254
            }
            else {
255
256
                int index = exclusionTileMap[make_pair(y, x)]*OpenCLContext::TileSize;
                exclusionVec[index+offset2] &= allFlags-(1<<offset1);
257
258
259
260
261
            }
        }
    }
    atomExclusions.clear(); // We won't use this again, so free the memory it used
    exclusions->upload(exclusionVec);
262
263
264
265

    // Create data structures for the neighbor list.

    if (useCutoff) {
266
267
268
269
270
271
272
273
274
        // Select a size for the arrays that hold the neighbor list.  We have to make a fairly
        // arbitrary guess, but if this turns out to be too small we'll increase it later.

        int maxTiles = 20*numAtomBlocks;
        if (maxTiles > numTiles)
            maxTiles = numTiles;
        if (maxTiles < 1)
            maxTiles = 1;
        int numAtoms = context.getNumAtoms();
275
        interactingTiles = OpenCLArray::create<cl_int>(context, maxTiles, "interactingTiles");
276
        interactingAtoms = OpenCLArray::create<cl_int>(context, OpenCLContext::TileSize*maxTiles, "interactingAtoms");
277
        interactionCount = OpenCLArray::create<cl_uint>(context, 1, "interactionCount");
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
        int elementSize = (context.getUseDoublePrecision() ? sizeof(cl_double) : sizeof(cl_float));
        blockCenter = new OpenCLArray(context, numAtomBlocks, 4*elementSize, "blockCenter");
        blockBoundingBox = new OpenCLArray(context, numAtomBlocks, 4*elementSize, "blockBoundingBox");
        sortedBlocks = new OpenCLArray(context, numAtomBlocks, 2*elementSize, "sortedBlocks");
        sortedBlockCenter = new OpenCLArray(context, numAtomBlocks+1, 4*elementSize, "sortedBlockCenter");
        sortedBlockBoundingBox = new OpenCLArray(context, numAtomBlocks+1, 4*elementSize, "sortedBlockBoundingBox");
        oldPositions = new OpenCLArray(context, numAtoms, 4*elementSize, "oldPositions");
        if (context.getUseDoublePrecision()) {
            vector<mm_double4> oldPositionsVec(numAtoms, mm_double4(1e30, 1e30, 1e30, 0));
            oldPositions->upload(oldPositionsVec);
        }
        else {
            vector<mm_float4> oldPositionsVec(numAtoms, mm_float4(1e30f, 1e30f, 1e30f, 0));
            oldPositions->upload(oldPositionsVec);
        }
        rebuildNeighborList = OpenCLArray::create<int>(context, 1, "rebuildNeighborList");
        blockSorter = new OpenCLSort(context, new BlockSortTrait(context.getUseDoublePrecision()), numAtomBlocks);
295
296
        vector<cl_uint> count(1, 0);
        interactionCount->upload(count);
297
    }
298
299
300

    // Create kernels.

Peter Eastman's avatar
Peter Eastman committed
301
302
    if (kernelSource.size() > 0)
        forceKernel = createInteractionKernel(kernelSource, parameters, arguments, true, true);
303
    if (useCutoff) {
304
305
        double padding = (usePadding ? 0.1*cutoff : 0.0);
        double paddedCutoff = cutoff+padding;
306
        map<string, string> defines;
307
308
309
310
311
312
        defines["TILE_SIZE"] = context.intToString(OpenCLContext::TileSize);
        defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
        defines["PADDING"] = context.doubleToString(padding);
        defines["PADDED_CUTOFF"] = context.doubleToString(paddedCutoff);
        defines["PADDED_CUTOFF_SQUARED"] = context.doubleToString(paddedCutoff*paddedCutoff);
        defines["NUM_TILES_WITH_EXCLUSIONS"] = context.intToString(exclusionTiles->getSize());
313
        defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
314
315
        if (usePeriodic)
            defines["USE_PERIODIC"] = "1";
316
317
318
319
320
        int maxExclusions = 0;
        for (int i = 0; i < (int) exclusionBlocksForBlock.size(); i++)
            maxExclusions = (maxExclusions > exclusionBlocksForBlock[i].size() ? maxExclusions : exclusionBlocksForBlock[i].size());
        defines["MAX_EXCLUSIONS"] = context.intToString(maxExclusions);
        defines["BUFFER_GROUPS"] = (deviceIsCpu ? "4" : "2");
321
        string file = (deviceIsCpu ? OpenCLKernelSources::findInteractingBlocks_cpu : OpenCLKernelSources::findInteractingBlocks);
322
        int groupSize = (deviceIsCpu || context.getSIMDWidth() < 32 ? 32 : 256);
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
        while (true) {
            defines["GROUP_SIZE"] = context.intToString(groupSize);
            cl::Program interactingBlocksProgram = context.createProgram(file, defines);
            findBlockBoundsKernel = cl::Kernel(interactingBlocksProgram, "findBlockBounds");
            findBlockBoundsKernel.setArg<cl_int>(0, context.getNumAtoms());
            findBlockBoundsKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
            findBlockBoundsKernel.setArg<cl::Buffer>(4, blockCenter->getDeviceBuffer());
            findBlockBoundsKernel.setArg<cl::Buffer>(5, blockBoundingBox->getDeviceBuffer());
            findBlockBoundsKernel.setArg<cl::Buffer>(6, rebuildNeighborList->getDeviceBuffer());
            findBlockBoundsKernel.setArg<cl::Buffer>(7, sortedBlocks->getDeviceBuffer());
            sortBoxDataKernel = cl::Kernel(interactingBlocksProgram, "sortBoxData");
            sortBoxDataKernel.setArg<cl::Buffer>(0, sortedBlocks->getDeviceBuffer());
            sortBoxDataKernel.setArg<cl::Buffer>(1, blockCenter->getDeviceBuffer());
            sortBoxDataKernel.setArg<cl::Buffer>(2, blockBoundingBox->getDeviceBuffer());
            sortBoxDataKernel.setArg<cl::Buffer>(3, sortedBlockCenter->getDeviceBuffer());
            sortBoxDataKernel.setArg<cl::Buffer>(4, sortedBlockBoundingBox->getDeviceBuffer());
            sortBoxDataKernel.setArg<cl::Buffer>(5, context.getPosq().getDeviceBuffer());
            sortBoxDataKernel.setArg<cl::Buffer>(6, oldPositions->getDeviceBuffer());
            sortBoxDataKernel.setArg<cl::Buffer>(7, interactionCount->getDeviceBuffer());
            sortBoxDataKernel.setArg<cl::Buffer>(8, rebuildNeighborList->getDeviceBuffer());
            findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions");
            findInteractingBlocksKernel.setArg<cl::Buffer>(2, interactionCount->getDeviceBuffer());
            findInteractingBlocksKernel.setArg<cl::Buffer>(3, interactingTiles->getDeviceBuffer());
            findInteractingBlocksKernel.setArg<cl::Buffer>(4, interactingAtoms->getDeviceBuffer());
            findInteractingBlocksKernel.setArg<cl::Buffer>(5, context.getPosq().getDeviceBuffer());
            findInteractingBlocksKernel.setArg<cl_uint>(6, interactingTiles->getSize());
            findInteractingBlocksKernel.setArg<cl_uint>(7, startBlockIndex);
            findInteractingBlocksKernel.setArg<cl_uint>(8, numBlocks);
            findInteractingBlocksKernel.setArg<cl::Buffer>(9, sortedBlocks->getDeviceBuffer());
            findInteractingBlocksKernel.setArg<cl::Buffer>(10, sortedBlockCenter->getDeviceBuffer());
            findInteractingBlocksKernel.setArg<cl::Buffer>(11, sortedBlockBoundingBox->getDeviceBuffer());
            findInteractingBlocksKernel.setArg<cl::Buffer>(12, exclusionIndices->getDeviceBuffer());
            findInteractingBlocksKernel.setArg<cl::Buffer>(13, exclusionRowIndices->getDeviceBuffer());
            findInteractingBlocksKernel.setArg<cl::Buffer>(14, oldPositions->getDeviceBuffer());
            findInteractingBlocksKernel.setArg<cl::Buffer>(15, rebuildNeighborList->getDeviceBuffer());
            if (findInteractingBlocksKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice()) < groupSize) {
                // The device can't handle this block size, so reduce it.
                
                groupSize -= 32;
                if (groupSize < 32)
                    throw OpenMMException("Failed to create findInteractingBlocks kernel");
                continue;
            }
            break;
        }
368
        interactingBlocksThreadBlockSize = (deviceIsCpu ? 1 : groupSize);
369
    }
370
371
}

372
373
374
375
376
377
378
379
380
381
382
383
384
385
/**
 * Set a kernel argument to the periodic box size, using the double or single precision
 * value to match the precision the context is running in.
 */
static void setPeriodicBoxSizeArg(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (!cl.getUseDoublePrecision()) {
        kernel.setArg<mm_float4>(index, cl.getPeriodicBoxSize());
        return;
    }
    kernel.setArg<mm_double4>(index, cl.getPeriodicBoxSizeDouble());
}

/**
 * Set a kernel argument to the inverse periodic box size, using the double or single
 * precision value to match the precision the context is running in.
 */
static void setInvPeriodicBoxSizeArg(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (!cl.getUseDoublePrecision()) {
        kernel.setArg<mm_float4>(index, cl.getInvPeriodicBoxSize());
        return;
    }
    kernel.setArg<mm_double4>(index, cl.getInvPeriodicBoxSizeDouble());
}

386
387
388
void OpenCLNonbondedUtilities::prepareInteractions() {
    if (!useCutoff)
        return;
389
390
    if (numTiles == 0)
        return;
391
392
    if (usePeriodic) {
        mm_float4 box = context.getPeriodicBoxSize();
393
        double minAllowedSize = 1.999999*cutoff;
394
395
396
        if (box.x < minAllowedSize || box.y < minAllowedSize || box.z < minAllowedSize)
            throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
    }
397
398
399

    // Compute the neighbor list.

400
401
    setPeriodicBoxSizeArg(context, findBlockBoundsKernel, 1);
    setInvPeriodicBoxSizeArg(context, findBlockBoundsKernel, 2);
402
    context.executeKernel(findBlockBoundsKernel, context.getNumAtoms());
403
404
405
406
    blockSorter->sort(*sortedBlocks);
    context.executeKernel(sortBoxDataKernel, context.getNumAtoms());
    setPeriodicBoxSizeArg(context, findInteractingBlocksKernel, 0);
    setInvPeriodicBoxSizeArg(context, findInteractingBlocksKernel, 1);
407
    context.executeKernel(findInteractingBlocksKernel, context.getNumAtoms(), interactingBlocksThreadBlockSize);
408
409
410
}

void OpenCLNonbondedUtilities::computeInteractions() {
Peter Eastman's avatar
Peter Eastman committed
411
    if (kernelSource.size() > 0) {
412
        if (useCutoff) {
413
414
            setPeriodicBoxSizeArg(context, forceKernel, 9);
            setInvPeriodicBoxSizeArg(context, forceKernel, 10);
415
        }
416
        context.executeKernel(forceKernel, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
417
418
        if (context.getComputeForceCount() == 1)
            updateNeighborListSize(); // This is the first time step, so check whether our initial guess was large enough.
419
    }
420
421
}

422
423
424
void OpenCLNonbondedUtilities::updateNeighborListSize() {
    if (!useCutoff)
        return;
425
426
427
    unsigned int* pinnedInteractionCount = (unsigned int*) context.getPinnedBuffer();
    interactionCount->download(pinnedInteractionCount);
    if (pinnedInteractionCount[0] <= (unsigned int) interactingTiles->getSize())
428
429
430
431
432
        return;

    // The most recent timestep had too many interactions to fit in the arrays.  Make the arrays bigger to prevent
    // this from happening in the future.

433
434
435
436
    int maxTiles = (int) (1.2*pinnedInteractionCount[0]);
    int totalTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    if (maxTiles > totalTiles)
        maxTiles = totalTiles;
437
    delete interactingTiles;
438
439
440
    delete interactingAtoms;
    interactingTiles = NULL; // Avoid an error in the destructor if the following allocation fails
    interactingAtoms = NULL;
441
    interactingTiles = OpenCLArray::create<cl_int>(context, maxTiles, "interactingTiles");
442
443
444
    interactingAtoms = OpenCLArray::create<cl_int>(context, OpenCLContext::TileSize*maxTiles, "interactingAtoms");
    forceKernel.setArg<cl::Buffer>(7, interactingTiles->getDeviceBuffer());
    forceKernel.setArg<cl_uint>(11, maxTiles);
445
    forceKernel.setArg<cl::Buffer>(14, interactingAtoms->getDeviceBuffer());
446
447
448
    findInteractingBlocksKernel.setArg<cl::Buffer>(3, interactingTiles->getDeviceBuffer());
    findInteractingBlocksKernel.setArg<cl::Buffer>(4, interactingAtoms->getDeviceBuffer());
    findInteractingBlocksKernel.setArg<cl_uint>(6, maxTiles);
449
450
451
452
453
454
455
456
    int numAtoms = context.getNumAtoms();
    if (context.getUseDoublePrecision()) {
        vector<mm_double4> oldPositionsVec(numAtoms, mm_double4(1e30, 1e30, 1e30, 0));
        oldPositions->upload(oldPositionsVec);
    }
    else {
        vector<mm_float4> oldPositionsVec(numAtoms, mm_float4(1e30f, 1e30f, 1e30f, 0));
        oldPositions->upload(oldPositionsVec);
457
458
459
    }
}

460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
void OpenCLNonbondedUtilities::setUsePadding(bool padding) {
    usePadding = padding;
}

void OpenCLNonbondedUtilities::setAtomBlockRange(double startFraction, double endFraction) {
    int numAtomBlocks = context.getNumAtomBlocks();
    startBlockIndex = (int) (startFraction*numAtomBlocks);
    numBlocks = (int) (endFraction*numAtomBlocks)-startBlockIndex;
    int totalTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    startTileIndex = (int) (startFraction*totalTiles);;
    numTiles = (int) (endFraction*totalTiles)-startTileIndex;
    if (useCutoff && interactingTiles != NULL) {
        // We are using a cutoff, and the kernels have already been created.
        
        forceKernel.setArg<cl_uint>(5, startTileIndex);
        forceKernel.setArg<cl_uint>(6, numTiles);
476
477
        findInteractingBlocksKernel.setArg<cl_uint>(7, startBlockIndex);
        findInteractingBlocksKernel.setArg<cl_uint>(8, numBlocks);
478
479
480
    }
}

481
/**
 * Build a "computeNonbonded" kernel from the nonbonded kernel template.  The interaction
 * source and the generated per-atom parameter declarations and loads are substituted into
 * the template, the program is compiled with defines describing the current settings, and
 * all arguments known at this point are set on the kernel.
 *
 * @param source          the code that computes the interaction, substituted for COMPUTE_INTERACTION
 * @param params          the per-atom parameters; each one becomes a "global_<name>" kernel argument
 * @param arguments       additional kernel arguments, passed after the parameters
 * @param useExclusions   whether to define USE_EXCLUSIONS
 * @param isSymmetric     whether to define USE_SYMMETRIC
 * @return the kernel.  When a cutoff is in use, the two periodic box size arguments are left
 *         unset; computeInteractions() sets them before each launch.
 */
cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& source, const vector<ParameterInfo>& params, const vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric) const {
    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = source;
    const string suffixes[] = {"x", "y", "z", "w"};

    // Declare the fields that hold each atom's parameters in local data.  Parameters
    // with several components are stored as one field per component.

    stringstream localData;
    int localDataSize = 0;
    for (int i = 0; i < (int) params.size(); i++) {
        if (params[i].getNumComponents() == 1)
            localData<<params[i].getType()<<" "<<params[i].getName()<<";\n";
        else {
            for (int j = 0; j < params[i].getNumComponents(); ++j)
                localData<<params[i].getComponentType()<<" "<<params[i].getName()<<"_"<<suffixes[j]<<";\n";
        }
        localDataSize += params[i].getSize();
    }
    replacements["ATOM_PARAMETER_DATA"] = localData.str();

    // Declare the extra kernel arguments: first the parameters, then the additional arguments.

    stringstream args;
    for (int i = 0; i < (int) params.size(); i++) {
        args << ", __global const ";
        args << params[i].getType();
        args << "* restrict global_";
        args << params[i].getName();
    }
    for (int i = 0; i < (int) arguments.size(); i++) {
        if (arguments[i].getMemory().getInfo<CL_MEM_TYPE>() == CL_MEM_OBJECT_IMAGE2D) {
            args << ", __read_only image2d_t ";
            args << arguments[i].getName();
        }
        else {
            // Buffers created with CL_MEM_READ_ONLY are declared __constant; all others are __global const.

            if ((arguments[i].getMemory().getInfo<CL_MEM_FLAGS>() & CL_MEM_READ_ONLY) == 0)
                args << ", __global const ";
            else
                args << ", __constant ";
            args << arguments[i].getType();
            args << "* restrict ";
            args << arguments[i].getName();
        }
    }
    replacements["PARAMETER_ARGUMENTS"] = args.str();

    // Copy the parameters already loaded for atom 1 into local data.

    stringstream loadLocal1;
    for (int i = 0; i < (int) params.size(); i++) {
        if (params[i].getNumComponents() == 1) {
            loadLocal1<<"localData[localAtomIndex]."<<params[i].getName()<<" = "<<params[i].getName()<<"1;\n";
        }
        else {
            for (int j = 0; j < params[i].getNumComponents(); ++j)
                loadLocal1<<"localData[localAtomIndex]."<<params[i].getName()<<"_"<<suffixes[j]<<" = "<<params[i].getName()<<"1."<<suffixes[j]<<";\n";
        }
    }
    replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();

    // Load parameters from global memory into local data.

    stringstream loadLocal2;
    for (int i = 0; i < (int) params.size(); i++) {
        if (params[i].getNumComponents() == 1) {
            loadLocal2<<"localData[localAtomIndex]."<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
        }
        else {
            loadLocal2<<params[i].getType()<<" temp_"<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
            for (int j = 0; j < params[i].getNumComponents(); ++j)
                loadLocal2<<"localData[localAtomIndex]."<<params[i].getName()<<"_"<<suffixes[j]<<" = temp_"<<params[i].getName()<<"."<<suffixes[j]<<";\n";
        }
    }
    replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();

    // Load the parameters for atom 1 from global memory.

    stringstream load1;
    for (int i = 0; i < (int) params.size(); i++) {
        load1 << params[i].getType();
        load1 << " ";
        load1 << params[i].getName();
        load1 << "1 = global_";
        load1 << params[i].getName();
        load1 << "[atom1];\n";
    }
    replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();

    // Load the parameters for atom 2 from local data, reassembling multi-component values.

    stringstream load2j;
    for (int i = 0; i < (int) params.size(); i++) {
        if (params[i].getNumComponents() == 1) {
            load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = localData[atom2]."<<params[i].getName()<<";\n";
        }
        else {
            load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = ("<<params[i].getType()<<") (";
            for (int j = 0; j < params[i].getNumComponents(); ++j) {
                if (j > 0)
                    load2j<<", ";
                load2j<<"localData[atom2]."<<params[i].getName()<<"_"<<suffixes[j];
            }
            load2j<<");\n";
        }
    }
    replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str();

    // Select the preprocessor definitions.

    map<string, string> defines;
    if (useCutoff)
        defines["USE_CUTOFF"] = "1";
    if (usePeriodic)
        defines["USE_PERIODIC"] = "1";
    if (useExclusions)
        defines["USE_EXCLUSIONS"] = "1";
    if (isSymmetric)
        defines["USE_SYMMETRIC"] = "1";
    if (useCutoff && context.getSIMDWidth() < 32)
        defines["PRUNE_BY_CUTOFF"] = "1";
    defines["FORCE_WORK_GROUP_SIZE"] = context.intToString(forceThreadBlockSize);
    defines["CUTOFF_SQUARED"] = context.doubleToString(cutoff*cutoff);
    defines["CUTOFF"] = context.doubleToString(cutoff);
    defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
    defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
    defines["TILE_SIZE"] = context.intToString(OpenCLContext::TileSize);

    // Give this context its share of the tiles that contain exclusions.

    int numExclusionTiles = exclusionTiles->getSize();
    defines["NUM_TILES_WITH_EXCLUSIONS"] = context.intToString(numExclusionTiles);
    int numContexts = context.getPlatformData().contexts.size();
    int startExclusionIndex = context.getContextIndex()*numExclusionTiles/numContexts;
    int endExclusionIndex = (context.getContextIndex()+1)*numExclusionTiles/numContexts;
    defines["FIRST_EXCLUSION_TILE"] = context.intToString(startExclusionIndex);
    defines["LAST_EXCLUSION_TILE"] = context.intToString(endExclusionIndex);
    if ((localDataSize/4)%2 == 0)
        defines["PARAMETER_SIZE_IS_EVEN"] = "1";

    // Compile the program and create the kernel.

    string file;
    if (deviceIsCpu)
        file = OpenCLKernelSources::nonbonded_cpu;
    else
        file = OpenCLKernelSources::nonbonded;
    cl::Program program = context.createProgram(context.replaceStrings(file, replacements), defines);
    cl::Kernel kernel(program, "computeNonbonded");

    // Set arguments to the Kernel.

    int index = 0;
    if (context.getSupports64BitGlobalAtomics())
        kernel.setArg<cl::Memory>(index++, context.getLongForceBuffer().getDeviceBuffer());
    else
        kernel.setArg<cl::Buffer>(index++, context.getForceBuffers().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusions->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionTiles->getDeviceBuffer());
    kernel.setArg<cl_uint>(index++, startTileIndex);
    kernel.setArg<cl_uint>(index++, numTiles);
    if (useCutoff) {
        kernel.setArg<cl::Buffer>(index++, interactingTiles->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, interactionCount->getDeviceBuffer());
        index += 2; // The periodic box size arguments are set when the kernel is executed.
        kernel.setArg<cl_uint>(index++, interactingTiles->getSize());
        kernel.setArg<cl::Buffer>(index++, blockCenter->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, blockBoundingBox->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, interactingAtoms->getDeviceBuffer());
    }
    for (int i = 0; i < (int) params.size(); i++) {
        kernel.setArg<cl::Memory>(index++, params[i].getMemory());
    }
    for (int i = 0; i < (int) arguments.size(); i++) {
        kernel.setArg<cl::Memory>(index++, arguments[i].getMemory());
    }
    return kernel;
}