OpenCLNonbondedUtilities.cpp 36.6 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2015 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

27
#include "openmm/OpenMMException.h"
28
29
#include "OpenCLNonbondedUtilities.h"
#include "OpenCLArray.h"
30
#include "OpenCLKernelSources.h"
31
#include "OpenCLExpressionUtilities.h"
32
33
#include "OpenCLSort.h"
#include <algorithm>
34
#include <map>
35
36
#include <set>
#include <utility>
37
38
39
40

using namespace OpenMM;
using namespace std;

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
/**
 * Sort trait describing the elements handled by the block sorter.  Each element
 * is a "real2" value whose x component is used as the sort key; the host-side
 * element and key sizes depend on whether double precision is in use.
 */
class OpenCLNonbondedUtilities::BlockSortTrait : public OpenCLSort::SortTrait {
public:
    BlockSortTrait(bool useDouble) : doublePrecision(useDouble) {
    }
    // Size in bytes of one element being sorted.
    int getDataSize() const {
        if (doublePrecision)
            return sizeof(mm_double2);
        return sizeof(mm_float2);
    }
    // Size in bytes of one sort key.
    int getKeySize() const {
        if (doublePrecision)
            return sizeof(cl_double);
        return sizeof(cl_float);
    }
    // Type names and expressions inserted into the sorting kernel source.
    const char* getDataType() const {
        return "real2";
    }
    const char* getKeyType() const {
        return "real";
    }
    const char* getMinKey() const {
        return "-MAXFLOAT";
    }
    const char* getMaxKey() const {
        return "MAXFLOAT";
    }
    const char* getMaxValue() const {
        return "(real2) (MAXFLOAT, MAXFLOAT)";
    }
    const char* getSortKey() const {
        return "value.x";
    }
private:
    bool doublePrecision;
};

57
/**
 * Create the utilities object for a context.  All device arrays start out unallocated
 * (NULL).  The constructor only decides how many thread blocks to use for the force
 * kernels, how large each block is, and how many force buffers are needed.
 *
 * @param context    the OpenCLContext these utilities belong to
 */
OpenCLNonbondedUtilities::OpenCLNonbondedUtilities(OpenCLContext& context) : context(context), useCutoff(false), usePeriodic(false), anyExclusions(false), usePadding(true),
        numForceBuffers(0), exclusionIndices(NULL), exclusionRowIndices(NULL), exclusionTiles(NULL), exclusions(NULL), interactingTiles(NULL), interactingAtoms(NULL),
        interactionCount(NULL), blockCenter(NULL), blockBoundingBox(NULL), sortedBlocks(NULL), sortedBlockCenter(NULL), sortedBlockBoundingBox(NULL),
        oldPositions(NULL), rebuildNeighborList(NULL), blockSorter(NULL), forceRebuildNeighborList(true), lastCutoff(0.0), groupFlags(0) {
    // Decide how many thread blocks and force buffers to use.

    deviceIsCpu = (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
    if (deviceIsCpu) {
        // On a CPU each thread block is a single work item with its own force buffer.

        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = 1;
        numForceBuffers = numForceThreadBlocks;
        return;
    }
    const bool has64BitAtomics = context.getSupports64BitGlobalAtomics();
    if (context.getSIMDWidth() == 32) {
        numForceThreadBlocks = (has64BitAtomics ? 4 : 3)*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
        forceThreadBlockSize = 256;
    }
    else {
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = (context.getSIMDWidth() >= 32 ? OpenCLContext::ThreadBlockSize : 32);
    }
    if (has64BitAtomics) {
        // Even though using longForceBuffer, still need a single forceBuffer for the reduceForces kernel to convert the long results into float4 which will be used by later kernels.
        numForceBuffers = 1;
    }
    else
        numForceBuffers = numForceThreadBlocks*forceThreadBlockSize/OpenCLContext::TileSize;
}

/**
 * Release every array and the block sorter owned by this object.  Members that were
 * never allocated are still NULL, and deleting a null pointer is a no-op, so no
 * explicit checks are needed.
 */
OpenCLNonbondedUtilities::~OpenCLNonbondedUtilities() {
    delete exclusionIndices;
    delete exclusionRowIndices;
    delete exclusionTiles;
    delete exclusions;
    delete interactingTiles;
    delete interactingAtoms;
    delete interactionCount;
    delete blockCenter;
    delete blockBoundingBox;
    delete sortedBlocks;
    delete sortedBlockCenter;
    delete sortedBlockBoundingBox;
    delete oldPositions;
    delete rebuildNeighborList;
    delete blockSorter;
}

128
void OpenCLNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup) {
129
    if (groupCutoff.size() > 0) {
130
131
132
133
        if (usesCutoff != useCutoff)
            throw OpenMMException("All Forces must agree on whether to use a cutoff");
        if (usesPeriodic != usePeriodic)
            throw OpenMMException("All Forces must agree on whether to use periodic boundary conditions");
134
135
        if (usesCutoff && groupCutoff.find(forceGroup) != groupCutoff.end() && groupCutoff[forceGroup] != cutoffDistance)
            throw OpenMMException("All Forces in a single force group must use the same cutoff distance");
136
    }
137
138
139
140
    if (usesExclusions)
        requestExclusions(exclusionList);
    useCutoff = usesCutoff;
    usePeriodic = usesPeriodic;
141
142
143
144
145
146
147
148
149
150
    groupCutoff[forceGroup] = cutoffDistance;
    groupFlags |= 1<<forceGroup;
    if (kernel.size() > 0) {
        if (groupKernelSource.find(forceGroup) == groupKernelSource.end())
            groupKernelSource[forceGroup] = "";
        map<string, string> replacements;
        replacements["CUTOFF"] = "CUTOFF_"+context.intToString(forceGroup);
        replacements["CUTOFF_SQUARED"] = "CUTOFF_"+context.intToString(forceGroup)+"_SQUARED";
        groupKernelSource[forceGroup] += context.replaceStrings(kernel, replacements)+"\n";
    }
151
152
153
154
155
156
157
158
159
160
161
162
}

/**
 * Append a per-atom parameter description to the list kept by this object.
 *
 * @param parameter    the parameter to record
 */
void OpenCLNonbondedUtilities::addParameter(const ParameterInfo& parameter) {
    parameters.insert(parameters.end(), parameter);
}

/**
 * Append an extra kernel argument description to the list kept by this object.
 *
 * @param parameter    the argument to record
 */
void OpenCLNonbondedUtilities::addArgument(const ParameterInfo& parameter) {
    arguments.insert(arguments.end(), parameter);
}

void OpenCLNonbondedUtilities::requestExclusions(const vector<vector<int> >& exclusionList) {
    if (anyExclusions) {
163
        bool sameExclusions = (exclusionList.size() == atomExclusions.size());
164
        for (int i = 0; i < (int) exclusionList.size() && sameExclusions; i++) {
165
166
            if (exclusionList[i].size() != atomExclusions[i].size())
                sameExclusions = false;
167
168
            set<int> expectedExclusions;
            expectedExclusions.insert(atomExclusions[i].begin(), atomExclusions[i].end());
169
            for (int j = 0; j < (int) exclusionList[i].size(); j++)
170
                if (expectedExclusions.find(exclusionList[i][j]) == expectedExclusions.end())
171
172
173
174
175
                    sameExclusions = false;
        }
        if (!sameExclusions)
            throw OpenMMException("All Forces must have identical exceptions");
    }
176
    else {
177
        atomExclusions = exclusionList;
178
179
        anyExclusions = true;
    }
180
181
}

182
/**
 * Ordering predicate for tiles: compare by y first, then by x.
 */
static bool compareUshort2(mm_ushort2 a, mm_ushort2 b) {
    // This version is used on devices with SIMD width of 32 or less.  It sorts tiles to improve cache efficiency.

    return ((a.y < b.y) || (a.y == b.y && a.x < b.x));
}

/**
 * Ordering predicate for tiles that places every diagonal tile (x == y) before every
 * off-diagonal one.  Diagonal tiles are ordered by x; off-diagonal tiles by y, then x.
 */
static bool compareUshort2LargeSIMD(mm_ushort2 a, mm_ushort2 b) {
    // This version is used on devices with SIMD width greater than 32.  It puts diagonal tiles before off-diagonal
    // ones to reduce thread divergence.

    const bool aIsDiagonal = (a.x == a.y);
    const bool bIsDiagonal = (b.x == b.y);
    if (aIsDiagonal != bIsDiagonal)
        return aIsDiagonal;
    if (aIsDiagonal)
        return (a.x < b.x);
    return ((a.y < b.y) || (a.y == b.y && a.x < b.x));
}

202
void OpenCLNonbondedUtilities::initialize(const System& system) {
203
204
    if (atomExclusions.size() == 0) {
        // No exclusions were specifically requested, so just mark every atom as not interacting with itself.
205

206
        atomExclusions.resize(context.getNumAtoms());
207
        for (int i = 0; i < (int) atomExclusions.size(); i++)
208
209
210
            atomExclusions[i].push_back(i);
    }

211
212
213
    // Create the list of tiles.

    int numAtomBlocks = context.getNumAtomBlocks();
214
    int numContexts = context.getPlatformData().contexts.size();
215
    setAtomBlockRange(context.getContextIndex()/(double) numContexts, (context.getContextIndex()+1)/(double) numContexts);
216

217
    // Build a list of tiles that contain exclusions.
218

219
220
221
222
223
224
225
226
227
    set<pair<int, int> > tilesWithExclusions;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            tilesWithExclusions.insert(make_pair(max(x, y), min(x, y)));
        }
    }
228
229
230
    vector<mm_ushort2> exclusionTilesVec;
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter)
        exclusionTilesVec.push_back(mm_ushort2((unsigned short) iter->first, (unsigned short) iter->second));
peastman's avatar
peastman committed
231
    sort(exclusionTilesVec.begin(), exclusionTilesVec.end(), context.getSIMDWidth() <= 32 ? compareUshort2 : compareUshort2LargeSIMD);
232
233
234
235
236
237
238
239
240
241
242
243
    exclusionTiles = OpenCLArray::create<mm_ushort2>(context, exclusionTilesVec.size(), "exclusionTiles");
    exclusionTiles->upload(exclusionTilesVec);
    map<pair<int, int>, int> exclusionTileMap;
    for (int i = 0; i < (int) exclusionTilesVec.size(); i++) {
        mm_ushort2 tile = exclusionTilesVec[i];
        exclusionTileMap[make_pair(tile.x, tile.y)] = i;
    }
    vector<vector<int> > exclusionBlocksForBlock(numAtomBlocks);
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter) {
        exclusionBlocksForBlock[iter->first].push_back(iter->second);
        if (iter->first != iter->second)
            exclusionBlocksForBlock[iter->second].push_back(iter->first);
244
245
246
    }
    vector<cl_uint> exclusionRowIndicesVec(numAtomBlocks+1, 0);
    vector<cl_uint> exclusionIndicesVec;
247
248
249
    for (int i = 0; i < numAtomBlocks; i++) {
        exclusionIndicesVec.insert(exclusionIndicesVec.end(), exclusionBlocksForBlock[i].begin(), exclusionBlocksForBlock[i].end());
        exclusionRowIndicesVec[i+1] = exclusionIndicesVec.size();
250
    }
251
252
253
    maxExclusions = 0;
    for (int i = 0; i < (int) exclusionBlocksForBlock.size(); i++)
        maxExclusions = (maxExclusions > exclusionBlocksForBlock[i].size() ? maxExclusions : exclusionBlocksForBlock[i].size());
254
255
    exclusionIndices = OpenCLArray::create<cl_uint>(context, exclusionIndicesVec.size(), "exclusionIndices");
    exclusionRowIndices = OpenCLArray::create<cl_uint>(context, exclusionRowIndicesVec.size(), "exclusionRowIndices");
256
257
    exclusionIndices->upload(exclusionIndicesVec);
    exclusionRowIndices->upload(exclusionRowIndicesVec);
258
259
260

    // Record the exclusion data.

261
    exclusions = OpenCLArray::create<cl_uint>(context, tilesWithExclusions.size()*OpenCLContext::TileSize, "exclusions");
262
263
    cl_uint allFlags = (cl_uint) -1;
    vector<cl_uint> exclusionVec(exclusions->getSize(), allFlags);
264
265
266
267
268
269
270
271
272
273
    for (int i = 0; i < exclusions->getSize(); ++i)
        exclusionVec[i] = 0xFFFFFFFF;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x > y) {
274
275
                int index = exclusionTileMap[make_pair(x, y)]*OpenCLContext::TileSize;
                exclusionVec[index+offset1] &= allFlags-(1<<offset2);
276
277
            }
            else {
278
279
                int index = exclusionTileMap[make_pair(y, x)]*OpenCLContext::TileSize;
                exclusionVec[index+offset2] &= allFlags-(1<<offset1);
280
281
282
283
284
            }
        }
    }
    atomExclusions.clear(); // We won't use this again, so free the memory it used
    exclusions->upload(exclusionVec);
285
286
287
288

    // Create data structures for the neighbor list.

    if (useCutoff) {
289
290
291
292
293
294
295
296
297
        // Select a size for the arrays that hold the neighbor list.  We have to make a fairly
        // arbitrary guess, but if this turns out to be too small we'll increase it later.

        int maxTiles = 20*numAtomBlocks;
        if (maxTiles > numTiles)
            maxTiles = numTiles;
        if (maxTiles < 1)
            maxTiles = 1;
        int numAtoms = context.getNumAtoms();
298
        interactingTiles = OpenCLArray::create<cl_int>(context, maxTiles, "interactingTiles");
299
        interactingAtoms = OpenCLArray::create<cl_int>(context, OpenCLContext::TileSize*maxTiles, "interactingAtoms");
300
        interactionCount = OpenCLArray::create<cl_uint>(context, 1, "interactionCount");
301
302
303
304
305
306
307
308
309
        int elementSize = (context.getUseDoublePrecision() ? sizeof(cl_double) : sizeof(cl_float));
        blockCenter = new OpenCLArray(context, numAtomBlocks, 4*elementSize, "blockCenter");
        blockBoundingBox = new OpenCLArray(context, numAtomBlocks, 4*elementSize, "blockBoundingBox");
        sortedBlocks = new OpenCLArray(context, numAtomBlocks, 2*elementSize, "sortedBlocks");
        sortedBlockCenter = new OpenCLArray(context, numAtomBlocks+1, 4*elementSize, "sortedBlockCenter");
        sortedBlockBoundingBox = new OpenCLArray(context, numAtomBlocks+1, 4*elementSize, "sortedBlockBoundingBox");
        oldPositions = new OpenCLArray(context, numAtoms, 4*elementSize, "oldPositions");
        rebuildNeighborList = OpenCLArray::create<int>(context, 1, "rebuildNeighborList");
        blockSorter = new OpenCLSort(context, new BlockSortTrait(context.getUseDoublePrecision()), numAtomBlocks);
310
311
        vector<cl_uint> count(1, 0);
        interactionCount->upload(count);
312
    }
313
314
}

315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
/**
 * Set the five consecutive kernel arguments that describe the periodic box: box size,
 * inverse box size, and the three box vectors, in that order.  The double or single
 * precision variants are used depending on the context's precision setting.
 *
 * @param cl        the context to read the box description from
 * @param kernel    the kernel whose arguments are set
 * @param index     the index of the first of the five arguments
 */
static void setPeriodicBoxArgs(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (cl.getUseDoublePrecision()) {
        kernel.setArg<mm_double4>(index, cl.getPeriodicBoxSizeDouble());
        kernel.setArg<mm_double4>(index+1, cl.getInvPeriodicBoxSizeDouble());
        kernel.setArg<mm_double4>(index+2, cl.getPeriodicBoxVecXDouble());
        kernel.setArg<mm_double4>(index+3, cl.getPeriodicBoxVecYDouble());
        kernel.setArg<mm_double4>(index+4, cl.getPeriodicBoxVecZDouble());
    }
    else {
        kernel.setArg<mm_float4>(index, cl.getPeriodicBoxSize());
        kernel.setArg<mm_float4>(index+1, cl.getInvPeriodicBoxSize());
        kernel.setArg<mm_float4>(index+2, cl.getPeriodicBoxVecX());
        kernel.setArg<mm_float4>(index+3, cl.getPeriodicBoxVecY());
        kernel.setArg<mm_float4>(index+4, cl.getPeriodicBoxVecZ());
    }
}

332
333
334
335
336
337
338
339
340
341
342
343
double OpenCLNonbondedUtilities::getMaxCutoffDistance() {
    double cutoff = 0.0;
    for (map<int, double>::const_iterator iter = groupCutoff.begin(); iter != groupCutoff.end(); ++iter)
        cutoff = max(cutoff, iter->second);
    return cutoff;
}

void OpenCLNonbondedUtilities::prepareInteractions(int forceGroups) {
    if ((forceGroups&groupFlags) == 0)
        return;
    if (groupKernels.find(forceGroups) == groupKernels.end())
        createKernelsForGroups(forceGroups);
344
345
    if (!useCutoff)
        return;
346
347
    if (numTiles == 0)
        return;
348
    KernelSet& kernels = groupKernels[forceGroups];
349
350
    if (usePeriodic) {
        mm_float4 box = context.getPeriodicBoxSize();
351
        double minAllowedSize = 1.999999*kernels.cutoffDistance;
352
353
354
        if (box.x < minAllowedSize || box.y < minAllowedSize || box.z < minAllowedSize)
            throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
    }
355
356
357

    // Compute the neighbor list.

358
359
    if (lastCutoff != kernels.cutoffDistance)
        forceRebuildNeighborList = true;
360
361
362
363
364
365
366
367
368
369
370
371
372
    bool rebuild = false;
    do {
        setPeriodicBoxArgs(context, kernels.findBlockBoundsKernel, 1);
        context.executeKernel(kernels.findBlockBoundsKernel, context.getNumAtoms());
        blockSorter->sort(*sortedBlocks);
        kernels.sortBoxDataKernel.setArg<cl_int>(9, forceRebuildNeighborList);
        context.executeKernel(kernels.sortBoxDataKernel, context.getNumAtoms());
        setPeriodicBoxArgs(context, kernels.findInteractingBlocksKernel, 0);
        context.executeKernel(kernels.findInteractingBlocksKernel, context.getNumAtoms(), interactingBlocksThreadBlockSize);
        forceRebuildNeighborList = false;
        if (context.getComputeForceCount() == 1)
            rebuild = updateNeighborListSize(); // This is the first time step, so check whether our initial guess was large enough.
    } while (rebuild);
373
    lastCutoff = kernels.cutoffDistance;
374
375
}

376
void OpenCLNonbondedUtilities::computeInteractions(int forceGroups, bool includeForces, bool includeEnergy) {
377
378
379
380
    if ((forceGroups&groupFlags) == 0)
        return;
    KernelSet& kernels = groupKernels[forceGroups];
    if (kernels.hasForces) {
381
382
383
        cl::Kernel& kernel = (includeForces ? (includeEnergy ? kernels.forceEnergyKernel : kernels.forceKernel) : kernels.energyKernel);
        if (*reinterpret_cast<cl_kernel*>(&kernel) == NULL)
            kernel = createInteractionKernel(kernels.source, parameters, arguments, true, true, forceGroups, includeForces, includeEnergy);
384
        if (useCutoff)
385
386
            setPeriodicBoxArgs(context, kernel, 9);
        context.executeKernel(kernel, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
387
    }
388
389
}

390
bool OpenCLNonbondedUtilities::updateNeighborListSize() {
391
    if (!useCutoff)
392
        return false;
393
394
395
    unsigned int* pinnedInteractionCount = (unsigned int*) context.getPinnedBuffer();
    interactionCount->download(pinnedInteractionCount);
    if (pinnedInteractionCount[0] <= (unsigned int) interactingTiles->getSize())
396
        return false;
397
398
399
400

    // The most recent timestep had too many interactions to fit in the arrays.  Make the arrays bigger to prevent
    // this from happening in the future.

401
402
403
404
    int maxTiles = (int) (1.2*pinnedInteractionCount[0]);
    int totalTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    if (maxTiles > totalTiles)
        maxTiles = totalTiles;
405
    delete interactingTiles;
406
407
408
    delete interactingAtoms;
    interactingTiles = NULL; // Avoid an error in the destructor if the following allocation fails
    interactingAtoms = NULL;
409
    interactingTiles = OpenCLArray::create<cl_int>(context, maxTiles, "interactingTiles");
410
    interactingAtoms = OpenCLArray::create<cl_int>(context, OpenCLContext::TileSize*maxTiles, "interactingAtoms");
411
    for (map<int, KernelSet>::iterator iter = groupKernels.begin(); iter != groupKernels.end(); ++iter) {
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
        KernelSet& kernels = iter->second;
        if (*reinterpret_cast<cl_kernel*>(&kernels.forceKernel) != NULL) {
            kernels.forceKernel.setArg<cl::Buffer>(7, interactingTiles->getDeviceBuffer());
            kernels.forceKernel.setArg<cl_uint>(14, maxTiles);
            kernels.forceKernel.setArg<cl::Buffer>(17, interactingAtoms->getDeviceBuffer());
        }
        if (*reinterpret_cast<cl_kernel*>(&kernels.energyKernel) != NULL) {
            kernels.energyKernel.setArg<cl::Buffer>(7, interactingTiles->getDeviceBuffer());
            kernels.energyKernel.setArg<cl_uint>(14, maxTiles);
            kernels.energyKernel.setArg<cl::Buffer>(17, interactingAtoms->getDeviceBuffer());
        }
        if (*reinterpret_cast<cl_kernel*>(&kernels.forceEnergyKernel) != NULL) {
            kernels.forceEnergyKernel.setArg<cl::Buffer>(7, interactingTiles->getDeviceBuffer());
            kernels.forceEnergyKernel.setArg<cl_uint>(14, maxTiles);
            kernels.forceEnergyKernel.setArg<cl::Buffer>(17, interactingAtoms->getDeviceBuffer());
        }
        kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
        kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactingAtoms->getDeviceBuffer());
        kernels.findInteractingBlocksKernel.setArg<cl_uint>(9, maxTiles);
431
432
    }
    forceRebuildNeighborList = true;
433
    return true;
434
435
}

436
437
438
439
440
441
442
443
444
445
446
void OpenCLNonbondedUtilities::setUsePadding(bool padding) {
    usePadding = padding;
}

void OpenCLNonbondedUtilities::setAtomBlockRange(double startFraction, double endFraction) {
    int numAtomBlocks = context.getNumAtomBlocks();
    startBlockIndex = (int) (startFraction*numAtomBlocks);
    numBlocks = (int) (endFraction*numAtomBlocks)-startBlockIndex;
    int totalTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    startTileIndex = (int) (startFraction*totalTiles);;
    numTiles = (int) (endFraction*totalTiles)-startTileIndex;
447
    if (useCutoff) {
448
        // We are using a cutoff, and the kernels have already been created.
449

450
        for (map<int, KernelSet>::iterator iter = groupKernels.begin(); iter != groupKernels.end(); ++iter) {
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
            KernelSet& kernels = iter->second;
            if (*reinterpret_cast<cl_kernel*>(&kernels.forceKernel) != NULL) {
                kernels.forceKernel.setArg<cl_uint>(5, startTileIndex);
                kernels.forceKernel.setArg<cl_uint>(6, numTiles);
            }
            if (*reinterpret_cast<cl_kernel*>(&kernels.energyKernel) != NULL) {
                kernels.energyKernel.setArg<cl_uint>(5, startTileIndex);
                kernels.energyKernel.setArg<cl_uint>(6, numTiles);
            }
            if (*reinterpret_cast<cl_kernel*>(&kernels.forceEnergyKernel) != NULL) {
                kernels.forceEnergyKernel.setArg<cl_uint>(5, startTileIndex);
                kernels.forceEnergyKernel.setArg<cl_uint>(6, numTiles);
            }
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(10, startBlockIndex);
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(11, numBlocks);
466
467
        }
        forceRebuildNeighborList = true;
468
469
470
    }
}

471
472
473
474
475
476
477
478
479
480
481
482
void OpenCLNonbondedUtilities::createKernelsForGroups(int groups) {
    KernelSet kernels;
    double cutoff = 0.0;
    string source;
    for (int i = 0; i < 32; i++) {
        if ((groups&(1<<i)) != 0) {
            cutoff = max(cutoff, groupCutoff[i]);
            source += groupKernelSource[i];
        }
    }
    kernels.hasForces = (source.size() > 0);
    kernels.cutoffDistance = cutoff;
483
    kernels.source = source;
484
    if (useCutoff) {
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
        double padding = (usePadding ? 0.1*cutoff : 0.0);
        double paddedCutoff = cutoff+padding;
        map<string, string> defines;
        defines["TILE_SIZE"] = context.intToString(OpenCLContext::TileSize);
        defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
        defines["PADDING"] = context.doubleToString(padding);
        defines["PADDED_CUTOFF"] = context.doubleToString(paddedCutoff);
        defines["PADDED_CUTOFF_SQUARED"] = context.doubleToString(paddedCutoff*paddedCutoff);
        defines["NUM_TILES_WITH_EXCLUSIONS"] = context.intToString(exclusionTiles->getSize());
        defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
        defines["SIMD_WIDTH"] = context.intToString(context.getSIMDWidth());
        if (usePeriodic)
            defines["USE_PERIODIC"] = "1";
        defines["MAX_EXCLUSIONS"] = context.intToString(maxExclusions);
        defines["BUFFER_GROUPS"] = (deviceIsCpu ? "4" : "2");
        string file = (deviceIsCpu ? OpenCLKernelSources::findInteractingBlocks_cpu : OpenCLKernelSources::findInteractingBlocks);
        int groupSize = (deviceIsCpu || context.getSIMDWidth() < 32 ? 32 : 256);
        while (true) {
            defines["GROUP_SIZE"] = context.intToString(groupSize);
            cl::Program interactingBlocksProgram = context.createProgram(file, defines);
            kernels.findBlockBoundsKernel = cl::Kernel(interactingBlocksProgram, "findBlockBounds");
            kernels.findBlockBoundsKernel.setArg<cl_int>(0, context.getNumAtoms());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(6, context.getPosq().getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(7, blockCenter->getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(8, blockBoundingBox->getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(9, rebuildNeighborList->getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(10, sortedBlocks->getDeviceBuffer());
            kernels.sortBoxDataKernel = cl::Kernel(interactingBlocksProgram, "sortBoxData");
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(0, sortedBlocks->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(1, blockCenter->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(2, blockBoundingBox->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(3, sortedBlockCenter->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(4, sortedBlockBoundingBox->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(5, context.getPosq().getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(6, oldPositions->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(7, interactionCount->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(8, rebuildNeighborList->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl_int>(9, true);
            kernels.findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions");
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(5, interactionCount->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactingAtoms->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(8, context.getPosq().getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(9, interactingTiles->getSize());
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(10, startBlockIndex);
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(11, numBlocks);
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(12, sortedBlocks->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(13, sortedBlockCenter->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(14, sortedBlockBoundingBox->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(15, exclusionIndices->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(16, exclusionRowIndices->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(17, oldPositions->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(18, rebuildNeighborList->getDeviceBuffer());
            if (kernels.findInteractingBlocksKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice()) < groupSize) {
                // The device can't handle this block size, so reduce it.
540

541
542
543
544
545
546
547
548
549
550
551
552
                groupSize -= 32;
                if (groupSize < 32)
                    throw OpenMMException("Failed to create findInteractingBlocks kernel");
                continue;
            }
            break;
        }
        interactingBlocksThreadBlockSize = (deviceIsCpu ? 1 : groupSize);
    }
    groupKernels[groups] = kernels;
}

553
/**
 * Compile a "computeNonbonded" kernel specialized for a particular interaction.
 *
 * The kernel template (nonbonded_cpu for CPU devices, nonbonded otherwise) is
 * filled in with code fragments generated from the per-atom parameters and extra
 * arguments, compiled with a set of preprocessor definitions describing the
 * current configuration, and returned with all of its fixed arguments set.
 *
 * @param source          the code that computes the interaction; substituted for COMPUTE_INTERACTION
 * @param params          the per-atom parameters; each is passed to the kernel as "global_<name>"
 * @param arguments       additional buffers or 2D images to pass to the kernel
 * @param useExclusions   whether to define USE_EXCLUSIONS
 * @param isSymmetric     whether to define USE_SYMMETRIC
 * @param groups          bit mask of the force groups handled by the kernel (bit i selects group i)
 * @param includeForces   whether to define INCLUDE_FORCES
 * @param includeEnergy   whether to define INCLUDE_ENERGY
 * @return the compiled kernel
 * @throws OpenMMException if a per-atom parameter has more than four components
 */
cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& source, const vector<ParameterInfo>& params, const vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups, bool includeForces, bool includeEnergy) {
    const int numSuffixes = 4;
    const string suffixes[] = {"x", "y", "z", "w"};

    // Multi-component parameters are unpacked component by component using the suffixes
    // above, so reject anything that would index past the end of that table.

    for (int i = 0; i < (int) params.size(); i++) {
        if (params[i].getNumComponents() > numSuffixes)
            throw OpenMMException("createInteractionKernel: per-atom parameters may have at most 4 components");
    }
    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = source;

    // Declare the fields that hold each atom's parameters in the local data structure.

    stringstream localData;
    int localDataSize = 0;
    for (int i = 0; i < (int) params.size(); i++) {
        if (params[i].getNumComponents() == 1)
            localData<<params[i].getType()<<" "<<params[i].getName()<<";\n";
        else {
            for (int j = 0; j < params[i].getNumComponents(); ++j)
                localData<<params[i].getComponentType()<<" "<<params[i].getName()<<"_"<<suffixes[j]<<";\n";
        }
        localDataSize += params[i].getSize();
    }
    replacements["ATOM_PARAMETER_DATA"] = localData.str();

    // Declare the extra kernel arguments: first the per-atom parameter arrays, then the
    // additional arguments.

    stringstream args;
    for (int i = 0; i < (int) params.size(); i++) {
        args << ", __global const ";
        args << params[i].getType();
        args << "* restrict global_";
        args << params[i].getName();
    }
    for (int i = 0; i < (int) arguments.size(); i++) {
        if (arguments[i].getMemory().getInfo<CL_MEM_TYPE>() == CL_MEM_OBJECT_IMAGE2D) {
            args << ", __read_only image2d_t ";
            args << arguments[i].getName();
        }
        else {
            // Buffers created with CL_MEM_READ_ONLY are declared __constant; all others are __global.

            if ((arguments[i].getMemory().getInfo<CL_MEM_FLAGS>() & CL_MEM_READ_ONLY) == 0)
                args << ", __global const ";
            else
                args << ", __constant ";
            args << arguments[i].getType();
            args << "* restrict ";
            args << arguments[i].getName();
        }
    }
    replacements["PARAMETER_ARGUMENTS"] = args.str();

    // Copy the parameters of atom 1 (already loaded into "<name>1") into the local data.

    stringstream loadLocal1;
    for (int i = 0; i < (int) params.size(); i++) {
        if (params[i].getNumComponents() == 1) {
            loadLocal1<<"localData[localAtomIndex]."<<params[i].getName()<<" = "<<params[i].getName()<<"1;\n";
        }
        else {
            for (int j = 0; j < params[i].getNumComponents(); ++j)
                loadLocal1<<"localData[localAtomIndex]."<<params[i].getName()<<"_"<<suffixes[j]<<" = "<<params[i].getName()<<"1."<<suffixes[j]<<";\n";
        }
    }
    replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();

    // Copy the parameters of atom j from the global arrays into the local data.

    stringstream loadLocal2;
    for (int i = 0; i < (int) params.size(); i++) {
        if (params[i].getNumComponents() == 1) {
            loadLocal2<<"localData[localAtomIndex]."<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
        }
        else {
            loadLocal2<<params[i].getType()<<" temp_"<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
            for (int j = 0; j < params[i].getNumComponents(); ++j)
                loadLocal2<<"localData[localAtomIndex]."<<params[i].getName()<<"_"<<suffixes[j]<<" = temp_"<<params[i].getName()<<"."<<suffixes[j]<<";\n";
        }
    }
    replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();

    // Load the parameters of atom 1 from the global arrays into "<name>1".

    stringstream load1;
    for (int i = 0; i < (int) params.size(); i++) {
        load1 << params[i].getType();
        load1 << " ";
        load1 << params[i].getName();
        load1 << "1 = global_";
        load1 << params[i].getName();
        load1 << "[atom1];\n";
    }
    replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();

    // Load the parameters of atom 2 from the local data into "<name>2", reassembling
    // multi-component values from their separate fields.

    stringstream load2j;
    for (int i = 0; i < (int) params.size(); i++) {
        if (params[i].getNumComponents() == 1) {
            load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = localData[atom2]."<<params[i].getName()<<";\n";
        }
        else {
            load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = ("<<params[i].getType()<<") (";
            for (int j = 0; j < params[i].getNumComponents(); ++j) {
                if (j > 0)
                    load2j<<", ";
                load2j<<"localData[atom2]."<<params[i].getName()<<"_"<<suffixes[j];
            }
            load2j<<");\n";
        }
    }
    replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str();

    // Assemble the preprocessor definitions.

    map<string, string> defines;
    if (useCutoff)
        defines["USE_CUTOFF"] = "1";
    if (usePeriodic)
        defines["USE_PERIODIC"] = "1";
    if (useExclusions)
        defines["USE_EXCLUSIONS"] = "1";
    if (isSymmetric)
        defines["USE_SYMMETRIC"] = "1";
    if (useCutoff && context.getSIMDWidth() < 32)
        defines["PRUNE_BY_CUTOFF"] = "1";
    if (includeForces)
        defines["INCLUDE_FORCES"] = "1";
    if (includeEnergy)
        defines["INCLUDE_ENERGY"] = "1";
    defines["FORCE_WORK_GROUP_SIZE"] = context.intToString(forceThreadBlockSize);

    // Define the cutoff of every selected force group, along with the largest of them.

    double maxCutoff = 0.0;
    for (int i = 0; i < 32; i++) {
        // Use an unsigned mask so that testing bit 31 does not shift into the sign bit.
        if ((groups&(1u<<i)) != 0) {
            double cutoff = groupCutoff[i];
            maxCutoff = max(maxCutoff, cutoff);
            defines["CUTOFF_"+context.intToString(i)+"_SQUARED"] = context.doubleToString(cutoff*cutoff);
            defines["CUTOFF_"+context.intToString(i)] = context.doubleToString(cutoff);
        }
    }
    defines["MAX_CUTOFF"] = context.doubleToString(maxCutoff);
    defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
    defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
    defines["TILE_SIZE"] = context.intToString(OpenCLContext::TileSize);

    // Divide the tiles with exclusions evenly between the contexts of the platform.

    int numExclusionTiles = exclusionTiles->getSize();
    defines["NUM_TILES_WITH_EXCLUSIONS"] = context.intToString(numExclusionTiles);
    int numContexts = context.getPlatformData().contexts.size();
    int startExclusionIndex = context.getContextIndex()*numExclusionTiles/numContexts;
    int endExclusionIndex = (context.getContextIndex()+1)*numExclusionTiles/numContexts;
    defines["FIRST_EXCLUSION_TILE"] = context.intToString(startExclusionIndex);
    defines["LAST_EXCLUSION_TILE"] = context.intToString(endExclusionIndex);

    // NOTE(review): this assumes getSize() is in bytes, so localDataSize/4 counts 4 byte words - confirm.

    if ((localDataSize/4)%2 == 0)
        defines["PARAMETER_SIZE_IS_EVEN"] = "1";

    // Compile the kernel.

    string file;
    if (deviceIsCpu)
        file = OpenCLKernelSources::nonbonded_cpu;
    else
        file = OpenCLKernelSources::nonbonded;
    cl::Program program = context.createProgram(context.replaceStrings(file, replacements), defines);
    cl::Kernel kernel(program, "computeNonbonded");

    // Set arguments to the Kernel.

    int index = 0;
    if (context.getSupports64BitGlobalAtomics())
        kernel.setArg<cl::Memory>(index++, context.getLongForceBuffer().getDeviceBuffer());
    else
        kernel.setArg<cl::Buffer>(index++, context.getForceBuffers().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusions->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionTiles->getDeviceBuffer());

    // The tile range occupies arguments 5 and 6, the positions setAtomBlockRange() updates.

    kernel.setArg<cl_uint>(index++, startTileIndex);
    kernel.setArg<cl_uint>(index++, numTiles);
    if (useCutoff) {
        kernel.setArg<cl::Buffer>(index++, interactingTiles->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, interactionCount->getDeviceBuffer());
        index += 5; // The periodic box size arguments are set when the kernel is executed.
        kernel.setArg<cl_uint>(index++, interactingTiles->getSize());
        kernel.setArg<cl::Buffer>(index++, blockCenter->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, blockBoundingBox->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, interactingAtoms->getDeviceBuffer());
    }
    for (int i = 0; i < (int) params.size(); i++) {
        kernel.setArg<cl::Memory>(index++, params[i].getMemory());
    }
    for (int i = 0; i < (int) arguments.size(); i++) {
        kernel.setArg<cl::Memory>(index++, arguments[i].getMemory());
    }
    return kernel;
}