"examples/cpp-examples/HelloArgonInC.c" did not exist on "346d3ce3244e9af797a12002ced79cbcb5190a5b"
OpenCLNonbondedUtilities.cpp 34.8 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2015 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

27
#include "openmm/OpenMMException.h"
28
29
#include "OpenCLNonbondedUtilities.h"
#include "OpenCLArray.h"
30
#include "OpenCLKernelSources.h"
31
#include "OpenCLExpressionUtilities.h"
32
33
#include "OpenCLSort.h"
#include <algorithm>
34
#include <map>
35
36
#include <set>
#include <utility>
37
38
39
40

using namespace OpenMM;
using namespace std;

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
/**
 * Sort trait handed to OpenCLSort for ordering atom blocks.  Each element is a
 * real2 and the x component is used as the sort key.  The element and key sizes
 * depend on whether the context runs in double precision.
 */
class OpenCLNonbondedUtilities::BlockSortTrait : public OpenCLSort::SortTrait {
public:
    BlockSortTrait(bool useDouble) : doublePrecision(useDouble) {
    }
    int getDataSize() const {
        if (doublePrecision)
            return sizeof(mm_double2);
        return sizeof(mm_float2);
    }
    int getKeySize() const {
        if (doublePrecision)
            return sizeof(cl_double);
        return sizeof(cl_float);
    }
    const char* getDataType() const {
        return "real2";
    }
    const char* getKeyType() const {
        return "real";
    }
    const char* getMinKey() const {
        return "-MAXFLOAT";
    }
    const char* getMaxKey() const {
        return "MAXFLOAT";
    }
    const char* getMaxValue() const {
        return "(real2) (MAXFLOAT, MAXFLOAT)";
    }
    const char* getSortKey() const {
        return "value.x";
    }
private:
    // True when elements are stored as double precision values.
    bool doublePrecision;
};

57
/**
 * Create the utilities object for a context.  Every device array pointer starts
 * out as NULL, and the launch configuration for the force kernel (number of
 * thread blocks, thread block size, number of force buffers) is selected based
 * on the device type, its SIMD width, and whether it supports 64 bit global atomics.
 */
OpenCLNonbondedUtilities::OpenCLNonbondedUtilities(OpenCLContext& context) :
        context(context), useCutoff(false), usePeriodic(false), anyExclusions(false), usePadding(true),
        numForceBuffers(0), exclusionIndices(NULL), exclusionRowIndices(NULL), exclusionTiles(NULL),
        exclusions(NULL), interactingTiles(NULL), interactingAtoms(NULL), interactionCount(NULL),
        blockCenter(NULL), blockBoundingBox(NULL), sortedBlocks(NULL), sortedBlockCenter(NULL),
        sortedBlockBoundingBox(NULL), oldPositions(NULL), rebuildNeighborList(NULL), blockSorter(NULL),
        forceRebuildNeighborList(true), lastCutoff(0.0), groupFlags(0) {
    // Decide how many thread blocks and force buffers to use.

    deviceIsCpu = (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
    if (deviceIsCpu) {
        // On a CPU each thread block is a single work item with its own force buffer.

        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = 1;
        numForceBuffers = numForceThreadBlocks;
        return;
    }
    const bool simdWidthIs32 = (context.getSIMDWidth() == 32);
    const bool has64BitAtomics = context.getSupports64BitGlobalAtomics();
    if (simdWidthIs32) {
        numForceThreadBlocks = (has64BitAtomics ? 4 : 3)*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
        forceThreadBlockSize = 256;
    }
    else {
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = (context.getSIMDWidth() >= 32 ? OpenCLContext::ThreadBlockSize : 32);
    }
    if (has64BitAtomics) {
        // Even though the long force buffer is used, a single forceBuffer is still needed so the
        // reduceForces kernel can convert the long results into float4 for use by later kernels.

        numForceBuffers = 1;
    }
    else
        numForceBuffers = numForceThreadBlocks*forceThreadBlockSize/OpenCLContext::TileSize;
}

/**
 * Release every device array and helper object owned by this instance.
 */
OpenCLNonbondedUtilities::~OpenCLNonbondedUtilities() {
    // Deleting a null pointer does nothing, so members that were never allocated need no check.

    delete exclusionIndices;
    delete exclusionRowIndices;
    delete exclusionTiles;
    delete exclusions;
    delete interactingTiles;
    delete interactingAtoms;
    delete interactionCount;
    delete blockCenter;
    delete blockBoundingBox;
    delete sortedBlocks;
    delete sortedBlockCenter;
    delete sortedBlockBoundingBox;
    delete oldPositions;
    delete rebuildNeighborList;
    delete blockSorter;
}

128
void OpenCLNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup) {
129
    if (groupCutoff.size() > 0) {
130
131
132
133
        if (usesCutoff != useCutoff)
            throw OpenMMException("All Forces must agree on whether to use a cutoff");
        if (usesPeriodic != usePeriodic)
            throw OpenMMException("All Forces must agree on whether to use periodic boundary conditions");
134
135
        if (usesCutoff && groupCutoff.find(forceGroup) != groupCutoff.end() && groupCutoff[forceGroup] != cutoffDistance)
            throw OpenMMException("All Forces in a single force group must use the same cutoff distance");
136
    }
137
138
139
140
    if (usesExclusions)
        requestExclusions(exclusionList);
    useCutoff = usesCutoff;
    usePeriodic = usesPeriodic;
141
142
143
144
145
146
147
148
149
150
    groupCutoff[forceGroup] = cutoffDistance;
    groupFlags |= 1<<forceGroup;
    if (kernel.size() > 0) {
        if (groupKernelSource.find(forceGroup) == groupKernelSource.end())
            groupKernelSource[forceGroup] = "";
        map<string, string> replacements;
        replacements["CUTOFF"] = "CUTOFF_"+context.intToString(forceGroup);
        replacements["CUTOFF_SQUARED"] = "CUTOFF_"+context.intToString(forceGroup)+"_SQUARED";
        groupKernelSource[forceGroup] += context.replaceStrings(kernel, replacements)+"\n";
    }
151
152
153
154
155
156
157
158
159
160
161
162
}

/**
 * Append a parameter description to the list that is later passed to
 * createInteractionKernel() when the force kernel is built.
 *
 * @param parameter    the description to record (stored by copy)
 */
void OpenCLNonbondedUtilities::addParameter(const ParameterInfo& parameter) {
    parameters.push_back(parameter);
}

/**
 * Append an argument description to the list that is later passed to
 * createInteractionKernel() when the force kernel is built.
 *
 * @param parameter    the description to record (stored by copy)
 */
void OpenCLNonbondedUtilities::addArgument(const ParameterInfo& parameter) {
    arguments.push_back(parameter);
}

void OpenCLNonbondedUtilities::requestExclusions(const vector<vector<int> >& exclusionList) {
    if (anyExclusions) {
163
        bool sameExclusions = (exclusionList.size() == atomExclusions.size());
164
        for (int i = 0; i < (int) exclusionList.size() && sameExclusions; i++) {
165
166
            if (exclusionList[i].size() != atomExclusions[i].size())
                sameExclusions = false;
167
168
            set<int> expectedExclusions;
            expectedExclusions.insert(atomExclusions[i].begin(), atomExclusions[i].end());
169
            for (int j = 0; j < (int) exclusionList[i].size(); j++)
170
                if (expectedExclusions.find(exclusionList[i][j]) == expectedExclusions.end())
171
172
173
174
175
                    sameExclusions = false;
        }
        if (!sameExclusions)
            throw OpenMMException("All Forces must have identical exceptions");
    }
176
    else {
177
        atomExclusions = exclusionList;
178
179
        anyExclusions = true;
    }
180
181
}

182
static bool compareUshort2(mm_ushort2 a, mm_ushort2 b) {
    // This version is used on devices with SIMD width of 32 or less.  It sorts tiles to improve cache efficiency.
    // Tiles are ordered by y first, with x breaking ties.

    if (a.y != b.y)
        return (a.y < b.y);
    return (a.x < b.x);
}

static bool compareUshort2LargeSIMD(mm_ushort2 a, mm_ushort2 b) {
    // This version is used on devices with SIMD width greater than 32.  It puts diagonal tiles before off-diagonal
    // ones to reduce thread divergence.

    const bool aIsDiagonal = (a.x == a.y);
    const bool bIsDiagonal = (b.x == b.y);
    if (aIsDiagonal != bIsDiagonal)
        return aIsDiagonal; // Exactly one tile is diagonal, and it goes first.
    if (aIsDiagonal)
        return (a.x < b.x);

    // Both tiles are off-diagonal: order by y, then by x.

    if (a.y != b.y)
        return (a.y < b.y);
    return (a.x < b.x);
}

202
void OpenCLNonbondedUtilities::initialize(const System& system) {
203
204
    if (atomExclusions.size() == 0) {
        // No exclusions were specifically requested, so just mark every atom as not interacting with itself.
205

206
        atomExclusions.resize(context.getNumAtoms());
207
        for (int i = 0; i < (int) atomExclusions.size(); i++)
208
209
210
            atomExclusions[i].push_back(i);
    }

211
212
213
    // Create the list of tiles.

    int numAtomBlocks = context.getNumAtomBlocks();
214
    int numContexts = context.getPlatformData().contexts.size();
215
    setAtomBlockRange(context.getContextIndex()/(double) numContexts, (context.getContextIndex()+1)/(double) numContexts);
216

217
    // Build a list of tiles that contain exclusions.
218

219
220
221
222
223
224
225
226
227
    set<pair<int, int> > tilesWithExclusions;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            tilesWithExclusions.insert(make_pair(max(x, y), min(x, y)));
        }
    }
228
229
230
    vector<mm_ushort2> exclusionTilesVec;
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter)
        exclusionTilesVec.push_back(mm_ushort2((unsigned short) iter->first, (unsigned short) iter->second));
peastman's avatar
peastman committed
231
    sort(exclusionTilesVec.begin(), exclusionTilesVec.end(), context.getSIMDWidth() <= 32 ? compareUshort2 : compareUshort2LargeSIMD);
232
233
234
235
236
237
238
239
240
241
242
243
    exclusionTiles = OpenCLArray::create<mm_ushort2>(context, exclusionTilesVec.size(), "exclusionTiles");
    exclusionTiles->upload(exclusionTilesVec);
    map<pair<int, int>, int> exclusionTileMap;
    for (int i = 0; i < (int) exclusionTilesVec.size(); i++) {
        mm_ushort2 tile = exclusionTilesVec[i];
        exclusionTileMap[make_pair(tile.x, tile.y)] = i;
    }
    vector<vector<int> > exclusionBlocksForBlock(numAtomBlocks);
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter) {
        exclusionBlocksForBlock[iter->first].push_back(iter->second);
        if (iter->first != iter->second)
            exclusionBlocksForBlock[iter->second].push_back(iter->first);
244
245
246
    }
    vector<cl_uint> exclusionRowIndicesVec(numAtomBlocks+1, 0);
    vector<cl_uint> exclusionIndicesVec;
247
248
249
    for (int i = 0; i < numAtomBlocks; i++) {
        exclusionIndicesVec.insert(exclusionIndicesVec.end(), exclusionBlocksForBlock[i].begin(), exclusionBlocksForBlock[i].end());
        exclusionRowIndicesVec[i+1] = exclusionIndicesVec.size();
250
    }
251
252
253
    maxExclusions = 0;
    for (int i = 0; i < (int) exclusionBlocksForBlock.size(); i++)
        maxExclusions = (maxExclusions > exclusionBlocksForBlock[i].size() ? maxExclusions : exclusionBlocksForBlock[i].size());
254
255
    exclusionIndices = OpenCLArray::create<cl_uint>(context, exclusionIndicesVec.size(), "exclusionIndices");
    exclusionRowIndices = OpenCLArray::create<cl_uint>(context, exclusionRowIndicesVec.size(), "exclusionRowIndices");
256
257
    exclusionIndices->upload(exclusionIndicesVec);
    exclusionRowIndices->upload(exclusionRowIndicesVec);
258
259
260

    // Record the exclusion data.

261
    exclusions = OpenCLArray::create<cl_uint>(context, tilesWithExclusions.size()*OpenCLContext::TileSize, "exclusions");
262
263
    cl_uint allFlags = (cl_uint) -1;
    vector<cl_uint> exclusionVec(exclusions->getSize(), allFlags);
264
265
266
267
268
269
270
271
272
273
    for (int i = 0; i < exclusions->getSize(); ++i)
        exclusionVec[i] = 0xFFFFFFFF;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x > y) {
274
275
                int index = exclusionTileMap[make_pair(x, y)]*OpenCLContext::TileSize;
                exclusionVec[index+offset1] &= allFlags-(1<<offset2);
276
277
            }
            else {
278
279
                int index = exclusionTileMap[make_pair(y, x)]*OpenCLContext::TileSize;
                exclusionVec[index+offset2] &= allFlags-(1<<offset1);
280
281
282
283
284
            }
        }
    }
    atomExclusions.clear(); // We won't use this again, so free the memory it used
    exclusions->upload(exclusionVec);
285
286
287
288

    // Create data structures for the neighbor list.

    if (useCutoff) {
289
290
291
292
293
294
295
296
297
        // Select a size for the arrays that hold the neighbor list.  We have to make a fairly
        // arbitrary guess, but if this turns out to be too small we'll increase it later.

        int maxTiles = 20*numAtomBlocks;
        if (maxTiles > numTiles)
            maxTiles = numTiles;
        if (maxTiles < 1)
            maxTiles = 1;
        int numAtoms = context.getNumAtoms();
298
        interactingTiles = OpenCLArray::create<cl_int>(context, maxTiles, "interactingTiles");
299
        interactingAtoms = OpenCLArray::create<cl_int>(context, OpenCLContext::TileSize*maxTiles, "interactingAtoms");
300
        interactionCount = OpenCLArray::create<cl_uint>(context, 1, "interactionCount");
301
302
303
304
305
306
307
308
309
        int elementSize = (context.getUseDoublePrecision() ? sizeof(cl_double) : sizeof(cl_float));
        blockCenter = new OpenCLArray(context, numAtomBlocks, 4*elementSize, "blockCenter");
        blockBoundingBox = new OpenCLArray(context, numAtomBlocks, 4*elementSize, "blockBoundingBox");
        sortedBlocks = new OpenCLArray(context, numAtomBlocks, 2*elementSize, "sortedBlocks");
        sortedBlockCenter = new OpenCLArray(context, numAtomBlocks+1, 4*elementSize, "sortedBlockCenter");
        sortedBlockBoundingBox = new OpenCLArray(context, numAtomBlocks+1, 4*elementSize, "sortedBlockBoundingBox");
        oldPositions = new OpenCLArray(context, numAtoms, 4*elementSize, "oldPositions");
        rebuildNeighborList = OpenCLArray::create<int>(context, 1, "rebuildNeighborList");
        blockSorter = new OpenCLSort(context, new BlockSortTrait(context.getUseDoublePrecision()), numAtomBlocks);
310
311
        vector<cl_uint> count(1, 0);
        interactionCount->upload(count);
312
    }
313
314
}

315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
/**
 * Set five consecutive kernel arguments, starting at index, to the periodic box size,
 * the inverse box size, and the three box vectors, using the precision the context
 * is running in.
 */
static void setPeriodicBoxArgs(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (!cl.getUseDoublePrecision()) {
        kernel.setArg<mm_float4>(index, cl.getPeriodicBoxSize());
        kernel.setArg<mm_float4>(index+1, cl.getInvPeriodicBoxSize());
        kernel.setArg<mm_float4>(index+2, cl.getPeriodicBoxVecX());
        kernel.setArg<mm_float4>(index+3, cl.getPeriodicBoxVecY());
        kernel.setArg<mm_float4>(index+4, cl.getPeriodicBoxVecZ());
        return;
    }
    kernel.setArg<mm_double4>(index, cl.getPeriodicBoxSizeDouble());
    kernel.setArg<mm_double4>(index+1, cl.getInvPeriodicBoxSizeDouble());
    kernel.setArg<mm_double4>(index+2, cl.getPeriodicBoxVecXDouble());
    kernel.setArg<mm_double4>(index+3, cl.getPeriodicBoxVecYDouble());
    kernel.setArg<mm_double4>(index+4, cl.getPeriodicBoxVecZDouble());
}

332
333
334
335
336
337
338
339
340
341
342
343
double OpenCLNonbondedUtilities::getMaxCutoffDistance() {
    double cutoff = 0.0;
    for (map<int, double>::const_iterator iter = groupCutoff.begin(); iter != groupCutoff.end(); ++iter)
        cutoff = max(cutoff, iter->second);
    return cutoff;
}

void OpenCLNonbondedUtilities::prepareInteractions(int forceGroups) {
    if ((forceGroups&groupFlags) == 0)
        return;
    if (groupKernels.find(forceGroups) == groupKernels.end())
        createKernelsForGroups(forceGroups);
344
345
    if (!useCutoff)
        return;
346
347
    if (numTiles == 0)
        return;
348
    KernelSet& kernels = groupKernels[forceGroups];
349
350
    if (usePeriodic) {
        mm_float4 box = context.getPeriodicBoxSize();
351
        double minAllowedSize = 1.999999*kernels.cutoffDistance;
352
353
354
        if (box.x < minAllowedSize || box.y < minAllowedSize || box.z < minAllowedSize)
            throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
    }
355
356
357

    // Compute the neighbor list.

358
359
    if (lastCutoff != kernels.cutoffDistance)
        forceRebuildNeighborList = true;
360
361
362
363
364
365
366
367
368
369
370
371
372
    bool rebuild = false;
    do {
        setPeriodicBoxArgs(context, kernels.findBlockBoundsKernel, 1);
        context.executeKernel(kernels.findBlockBoundsKernel, context.getNumAtoms());
        blockSorter->sort(*sortedBlocks);
        kernels.sortBoxDataKernel.setArg<cl_int>(9, forceRebuildNeighborList);
        context.executeKernel(kernels.sortBoxDataKernel, context.getNumAtoms());
        setPeriodicBoxArgs(context, kernels.findInteractingBlocksKernel, 0);
        context.executeKernel(kernels.findInteractingBlocksKernel, context.getNumAtoms(), interactingBlocksThreadBlockSize);
        forceRebuildNeighborList = false;
        if (context.getComputeForceCount() == 1)
            rebuild = updateNeighborListSize(); // This is the first time step, so check whether our initial guess was large enough.
    } while (rebuild);
373
    lastCutoff = kernels.cutoffDistance;
374
375
}

376
377
378
379
380
void OpenCLNonbondedUtilities::computeInteractions(int forceGroups) {
    if ((forceGroups&groupFlags) == 0)
        return;
    KernelSet& kernels = groupKernels[forceGroups];
    if (kernels.hasForces) {
381
        if (useCutoff)
382
383
            setPeriodicBoxArgs(context, kernels.forceKernel, 9);
        context.executeKernel(kernels.forceKernel, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
384
    }
385
386
}

387
bool OpenCLNonbondedUtilities::updateNeighborListSize() {
388
    if (!useCutoff)
389
        return false;
390
391
392
    unsigned int* pinnedInteractionCount = (unsigned int*) context.getPinnedBuffer();
    interactionCount->download(pinnedInteractionCount);
    if (pinnedInteractionCount[0] <= (unsigned int) interactingTiles->getSize())
393
        return false;
394
395
396
397

    // The most recent timestep had too many interactions to fit in the arrays.  Make the arrays bigger to prevent
    // this from happening in the future.

398
399
400
401
    int maxTiles = (int) (1.2*pinnedInteractionCount[0]);
    int totalTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    if (maxTiles > totalTiles)
        maxTiles = totalTiles;
402
    delete interactingTiles;
403
404
405
    delete interactingAtoms;
    interactingTiles = NULL; // Avoid an error in the destructor if the following allocation fails
    interactingAtoms = NULL;
406
    interactingTiles = OpenCLArray::create<cl_int>(context, maxTiles, "interactingTiles");
407
    interactingAtoms = OpenCLArray::create<cl_int>(context, OpenCLContext::TileSize*maxTiles, "interactingAtoms");
408
409
410
411
412
413
414
415
416
    for (map<int, KernelSet>::iterator iter = groupKernels.begin(); iter != groupKernels.end(); ++iter) {
        iter->second.forceKernel.setArg<cl::Buffer>(7, interactingTiles->getDeviceBuffer());
        iter->second.forceKernel.setArg<cl_uint>(14, maxTiles);
        iter->second.forceKernel.setArg<cl::Buffer>(17, interactingAtoms->getDeviceBuffer());
        iter->second.findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
        iter->second.findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactingAtoms->getDeviceBuffer());
        iter->second.findInteractingBlocksKernel.setArg<cl_uint>(9, maxTiles);
    }
    forceRebuildNeighborList = true;
417
    return true;
418
419
}

420
421
422
423
424
425
426
427
428
429
430
/**
 * Set whether padding is added to the cutoff.  The flag is read when kernels are
 * created, where it selects between a padding of 10% of the cutoff and no padding.
 *
 * @param padding    true to use padding, false to disable it
 */
void OpenCLNonbondedUtilities::setUsePadding(bool padding) {
    usePadding = padding;
}

void OpenCLNonbondedUtilities::setAtomBlockRange(double startFraction, double endFraction) {
    int numAtomBlocks = context.getNumAtomBlocks();
    startBlockIndex = (int) (startFraction*numAtomBlocks);
    numBlocks = (int) (endFraction*numAtomBlocks)-startBlockIndex;
    int totalTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    startTileIndex = (int) (startFraction*totalTiles);;
    numTiles = (int) (endFraction*totalTiles)-startTileIndex;
431
    if (useCutoff) {
432
        // We are using a cutoff, and the kernels have already been created.
433

434
435
436
437
438
439
440
        for (map<int, KernelSet>::iterator iter = groupKernels.begin(); iter != groupKernels.end(); ++iter) {
            iter->second.forceKernel.setArg<cl_uint>(5, startTileIndex);
            iter->second.forceKernel.setArg<cl_uint>(6, numTiles);
            iter->second.findInteractingBlocksKernel.setArg<cl_uint>(10, startBlockIndex);
            iter->second.findInteractingBlocksKernel.setArg<cl_uint>(11, numBlocks);
        }
        forceRebuildNeighborList = true;
441
442
443
    }
}

444
445
446
447
448
449
450
451
452
453
454
455
456
457
void OpenCLNonbondedUtilities::createKernelsForGroups(int groups) {
    KernelSet kernels;
    double cutoff = 0.0;
    string source;
    for (int i = 0; i < 32; i++) {
        if ((groups&(1<<i)) != 0) {
            cutoff = max(cutoff, groupCutoff[i]);
            source += groupKernelSource[i];
        }
    }
    kernels.hasForces = (source.size() > 0);
    kernels.cutoffDistance = cutoff;
    if (kernels.hasForces)
        kernels.forceKernel = createInteractionKernel(source, parameters, arguments, true, true, groups);
458
    if (useCutoff) {
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
        double padding = (usePadding ? 0.1*cutoff : 0.0);
        double paddedCutoff = cutoff+padding;
        map<string, string> defines;
        defines["TILE_SIZE"] = context.intToString(OpenCLContext::TileSize);
        defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
        defines["PADDING"] = context.doubleToString(padding);
        defines["PADDED_CUTOFF"] = context.doubleToString(paddedCutoff);
        defines["PADDED_CUTOFF_SQUARED"] = context.doubleToString(paddedCutoff*paddedCutoff);
        defines["NUM_TILES_WITH_EXCLUSIONS"] = context.intToString(exclusionTiles->getSize());
        defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
        defines["SIMD_WIDTH"] = context.intToString(context.getSIMDWidth());
        if (usePeriodic)
            defines["USE_PERIODIC"] = "1";
        defines["MAX_EXCLUSIONS"] = context.intToString(maxExclusions);
        defines["BUFFER_GROUPS"] = (deviceIsCpu ? "4" : "2");
        string file = (deviceIsCpu ? OpenCLKernelSources::findInteractingBlocks_cpu : OpenCLKernelSources::findInteractingBlocks);
        int groupSize = (deviceIsCpu || context.getSIMDWidth() < 32 ? 32 : 256);
        while (true) {
            defines["GROUP_SIZE"] = context.intToString(groupSize);
            cl::Program interactingBlocksProgram = context.createProgram(file, defines);
            kernels.findBlockBoundsKernel = cl::Kernel(interactingBlocksProgram, "findBlockBounds");
            kernels.findBlockBoundsKernel.setArg<cl_int>(0, context.getNumAtoms());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(6, context.getPosq().getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(7, blockCenter->getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(8, blockBoundingBox->getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(9, rebuildNeighborList->getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(10, sortedBlocks->getDeviceBuffer());
            kernels.sortBoxDataKernel = cl::Kernel(interactingBlocksProgram, "sortBoxData");
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(0, sortedBlocks->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(1, blockCenter->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(2, blockBoundingBox->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(3, sortedBlockCenter->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(4, sortedBlockBoundingBox->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(5, context.getPosq().getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(6, oldPositions->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(7, interactionCount->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(8, rebuildNeighborList->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl_int>(9, true);
            kernels.findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions");
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(5, interactionCount->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactingAtoms->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(8, context.getPosq().getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(9, interactingTiles->getSize());
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(10, startBlockIndex);
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(11, numBlocks);
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(12, sortedBlocks->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(13, sortedBlockCenter->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(14, sortedBlockBoundingBox->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(15, exclusionIndices->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(16, exclusionRowIndices->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(17, oldPositions->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(18, rebuildNeighborList->getDeviceBuffer());
            if (kernels.findInteractingBlocksKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice()) < groupSize) {
                // The device can't handle this block size, so reduce it.
514

515
516
517
518
519
520
521
522
523
524
525
526
527
                groupSize -= 32;
                if (groupSize < 32)
                    throw OpenMMException("Failed to create findInteractingBlocks kernel");
                continue;
            }
            break;
        }
        interactingBlocksThreadBlockSize = (deviceIsCpu ? 1 : groupSize);
    }
    groupKernels[groups] = kernels;
}

/**
 * Generate, compile and configure a "computeNonbonded" kernel.
 *
 * The nonbonded kernel template is specialized by text substitution (the
 * interaction source, the per-parameter declarations and load statements, and
 * the extra argument list) and by preprocessor definitions describing cutoffs,
 * system size and the range of exclusion tiles handled by this context.
 *
 * @param source         code substituted for COMPUTE_INTERACTION
 * @param params         parameters that become "global_<name>" kernel arguments and localData fields
 * @param arguments      additional kernel arguments, declared as image2d_t, __constant or
 *                       __global const depending on the type and flags of their memory object
 * @param useExclusions  whether to define USE_EXCLUSIONS
 * @param isSymmetric    whether to define USE_SYMMETRIC
 * @param groups         bit mask of force groups for which CUTOFF_<i> definitions are emitted
 * @return the kernel, with every argument set except the five that are skipped
 *         for the periodic box when a cutoff is in use
 */
cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& source, const vector<ParameterInfo>& params, const vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups) {
    // Generate all the parameter dependent code fragments in a single pass over the parameters.

    static const char* const componentSuffix[] = {"x", "y", "z", "w"};
    stringstream atomParameterData, kernelArgs, loadLocalFrom1, loadLocalFromGlobal, loadAtom1, loadAtom2;
    int localDataSize = 0;
    for (size_t p = 0; p < params.size(); p++) {
        const ParameterInfo& param = params[p];
        const int numComponents = param.getNumComponents();
        localDataSize += param.getSize();
        kernelArgs<<", __global const "<<param.getType()<<"* restrict global_"<<param.getName();
        loadAtom1<<param.getType()<<" "<<param.getName()<<"1 = global_"<<param.getName()<<"[atom1];\n";
        if (numComponents == 1) {
            atomParameterData<<param.getType()<<" "<<param.getName()<<";\n";
            loadLocalFrom1<<"localData[localAtomIndex]."<<param.getName()<<" = "<<param.getName()<<"1;\n";
            loadLocalFromGlobal<<"localData[localAtomIndex]."<<param.getName()<<" = global_"<<param.getName()<<"[j];\n";
            loadAtom2<<param.getType()<<" "<<param.getName()<<"2 = localData[atom2]."<<param.getName()<<";\n";
            continue;
        }

        // Vector parameters are stored in localData as one scalar field per component.

        loadLocalFromGlobal<<param.getType()<<" temp_"<<param.getName()<<" = global_"<<param.getName()<<"[j];\n";
        loadAtom2<<param.getType()<<" "<<param.getName()<<"2 = ("<<param.getType()<<") (";
        for (int c = 0; c < numComponents; ++c) {
            atomParameterData<<param.getComponentType()<<" "<<param.getName()<<"_"<<componentSuffix[c]<<";\n";
            loadLocalFrom1<<"localData[localAtomIndex]."<<param.getName()<<"_"<<componentSuffix[c]<<" = "<<param.getName()<<"1."<<componentSuffix[c]<<";\n";
            loadLocalFromGlobal<<"localData[localAtomIndex]."<<param.getName()<<"_"<<componentSuffix[c]<<" = temp_"<<param.getName()<<"."<<componentSuffix[c]<<";\n";
            if (c > 0)
                loadAtom2<<", ";
            loadAtom2<<"localData[atom2]."<<param.getName()<<"_"<<componentSuffix[c];
        }
        loadAtom2<<");\n";
    }

    // Append the declarations of the extra arguments after those of the parameters.

    for (size_t a = 0; a < arguments.size(); a++) {
        const ParameterInfo& argument = arguments[a];
        if (argument.getMemory().getInfo<CL_MEM_TYPE>() == CL_MEM_OBJECT_IMAGE2D) {
            kernelArgs<<", __read_only image2d_t "<<argument.getName();
            continue;
        }
        const bool readOnly = ((argument.getMemory().getInfo<CL_MEM_FLAGS>() & CL_MEM_READ_ONLY) != 0);
        kernelArgs<<(readOnly ? ", __constant " : ", __global const ");
        kernelArgs<<argument.getType()<<"* restrict "<<argument.getName();
    }

    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = source;
    replacements["ATOM_PARAMETER_DATA"] = atomParameterData.str();
    replacements["PARAMETER_ARGUMENTS"] = kernelArgs.str();
    replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocalFrom1.str();
    replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocalFromGlobal.str();
    replacements["LOAD_ATOM1_PARAMETERS"] = loadAtom1.str();
    replacements["LOAD_ATOM2_PARAMETERS"] = loadAtom2.str();

    // Assemble the preprocessor definitions.

    map<string, string> defines;
    if (useCutoff) {
        defines["USE_CUTOFF"] = "1";
        if (context.getSIMDWidth() < 32)
            defines["PRUNE_BY_CUTOFF"] = "1";
    }
    if (usePeriodic)
        defines["USE_PERIODIC"] = "1";
    if (useExclusions)
        defines["USE_EXCLUSIONS"] = "1";
    if (isSymmetric)
        defines["USE_SYMMETRIC"] = "1";
    defines["FORCE_WORK_GROUP_SIZE"] = context.intToString(forceThreadBlockSize);
    double maxCutoff = 0.0;
    for (int group = 0; group < 32; group++) {
        if ((groups&(1<<group)) == 0)
            continue;
        const double cutoff = groupCutoff[group];
        const string cutoffName = "CUTOFF_"+context.intToString(group);
        maxCutoff = max(maxCutoff, cutoff);
        defines[cutoffName+"_SQUARED"] = context.doubleToString(cutoff*cutoff);
        defines[cutoffName] = context.doubleToString(cutoff);
    }
    defines["MAX_CUTOFF"] = context.doubleToString(maxCutoff);
    defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
    defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
    defines["TILE_SIZE"] = context.intToString(OpenCLContext::TileSize);

    // The exclusion tiles are divided evenly between the contexts; this one takes its own slice.

    int numExclusionTiles = exclusionTiles->getSize();
    int numContexts = context.getPlatformData().contexts.size();
    int firstExclusionTile = context.getContextIndex()*numExclusionTiles/numContexts;
    int lastExclusionTile = (context.getContextIndex()+1)*numExclusionTiles/numContexts;
    defines["NUM_TILES_WITH_EXCLUSIONS"] = context.intToString(numExclusionTiles);
    defines["FIRST_EXCLUSION_TILE"] = context.intToString(firstExclusionTile);
    defines["LAST_EXCLUSION_TILE"] = context.intToString(lastExclusionTile);
    if ((localDataSize/4)%2 == 0)
        defines["PARAMETER_SIZE_IS_EVEN"] = "1";

    // Compile the kernel.

    string kernelFile = (deviceIsCpu ? OpenCLKernelSources::nonbonded_cpu : OpenCLKernelSources::nonbonded);
    cl::Program program = context.createProgram(context.replaceStrings(kernelFile, replacements), defines);
    cl::Kernel kernel(program, "computeNonbonded");

    // Set arguments to the Kernel.

    int nextArg = 0;
    if (context.getSupports64BitGlobalAtomics())
        kernel.setArg<cl::Memory>(nextArg++, context.getLongForceBuffer().getDeviceBuffer());
    else
        kernel.setArg<cl::Buffer>(nextArg++, context.getForceBuffers().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(nextArg++, context.getEnergyBuffer().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(nextArg++, context.getPosq().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(nextArg++, exclusions->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(nextArg++, exclusionTiles->getDeviceBuffer());
    kernel.setArg<cl_uint>(nextArg++, startTileIndex);
    kernel.setArg<cl_uint>(nextArg++, numTiles);
    if (useCutoff) {
        kernel.setArg<cl::Buffer>(nextArg++, interactingTiles->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(nextArg++, interactionCount->getDeviceBuffer());
        nextArg += 5; // The periodic box size arguments are set when the kernel is executed.
        kernel.setArg<cl_uint>(nextArg++, interactingTiles->getSize());
        kernel.setArg<cl::Buffer>(nextArg++, blockCenter->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(nextArg++, blockBoundingBox->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(nextArg++, interactingAtoms->getDeviceBuffer());
    }
    for (size_t p = 0; p < params.size(); p++)
        kernel.setArg<cl::Memory>(nextArg++, params[p].getMemory());
    for (size_t a = 0; a < arguments.size(); a++)
        kernel.setArg<cl::Memory>(nextArg++, arguments[a].getMemory());
    return kernel;
}