OpenCLNonbondedUtilities.cpp 38.3 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
9
 * Portions copyright (c) 2009-2016 Stanford University and the Authors.      *
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

27
#include "openmm/OpenMMException.h"
28
29
#include "OpenCLNonbondedUtilities.h"
#include "OpenCLArray.h"
30
#include "OpenCLKernelSources.h"
31
#include "OpenCLExpressionUtilities.h"
32
33
#include "OpenCLSort.h"
#include <algorithm>
34
#include <map>
35
36
#include <set>
#include <utility>
37
38
39
40

using namespace OpenMM;
using namespace std;

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
/**
 * Sort trait handed to OpenCLSort for ordering atom blocks.  Each element is a "real2"
 * whose x component is the sort key; the host-side element sizes depend on whether the
 * context runs in double precision.
 */
class OpenCLNonbondedUtilities::BlockSortTrait : public OpenCLSort::SortTrait {
public:
    BlockSortTrait(bool useDouble) : doublePrecision(useDouble) {
    }
    // Size in bytes of one sorted element.
    int getDataSize() const {
        if (doublePrecision)
            return sizeof(mm_double2);
        return sizeof(mm_float2);
    }
    // Size in bytes of one sort key.
    int getKeySize() const {
        if (doublePrecision)
            return sizeof(cl_double);
        return sizeof(cl_float);
    }
    // Type names and expressions substituted into the sort kernel source.
    const char* getDataType() const {
        return "real2";
    }
    const char* getKeyType() const {
        return "real";
    }
    const char* getMinKey() const {
        return "-MAXFLOAT";
    }
    const char* getMaxKey() const {
        return "MAXFLOAT";
    }
    const char* getMaxValue() const {
        return "(real2) (MAXFLOAT, MAXFLOAT)";
    }
    const char* getSortKey() const {
        return "value.x";
    }
private:
    bool doublePrecision;
};

57
/**
 * Create the utilities object for a context.  All device arrays start out unallocated
 * (they are created in initialize()); the constructor only chooses the launch
 * configuration for the force kernels and maps a pinned host buffer used to read back
 * the interaction count.
 */
OpenCLNonbondedUtilities::OpenCLNonbondedUtilities(OpenCLContext& context) :
        context(context), useCutoff(false), usePeriodic(false), anyExclusions(false), usePadding(true),
        numForceBuffers(0), exclusionIndices(NULL), exclusionRowIndices(NULL), exclusionTiles(NULL),
        exclusions(NULL), interactingTiles(NULL), interactingAtoms(NULL), interactionCount(NULL),
        blockCenter(NULL), blockBoundingBox(NULL), sortedBlocks(NULL), sortedBlockCenter(NULL),
        sortedBlockBoundingBox(NULL), oldPositions(NULL), rebuildNeighborList(NULL), blockSorter(NULL),
        pinnedCountBuffer(NULL), pinnedCountMemory(NULL), forceRebuildNeighborList(true), lastCutoff(0.0),
        groupFlags(0) {
    // Decide how many thread blocks and force buffers to use.

    deviceIsCpu = (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
    if (deviceIsCpu) {
        // On a CPU every thread block is a single work-item with its own force buffer.

        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = 1;
        numForceBuffers = numForceThreadBlocks;
    }
    else {
        const bool has64BitAtomics = context.getSupports64BitGlobalAtomics();
        if (context.getSIMDWidth() == 32) {
            // Scale the block count with the number of compute units.

            const int blocksPerComputeUnit = (has64BitAtomics ? 4 : 3);
            numForceThreadBlocks = blocksPerComputeUnit*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
            forceThreadBlockSize = 256;
        }
        else {
            numForceThreadBlocks = context.getNumThreadBlocks();
            forceThreadBlockSize = (context.getSIMDWidth() >= 32 ? OpenCLContext::ThreadBlockSize : 32);
        }
        if (has64BitAtomics) {
            // Even though using longForceBuffer, still need a single forceBuffer for the reduceForces
            // kernel to convert the long results into float4 which will be used by later kernels.

            numForceBuffers = 1;
        }
        else
            numForceBuffers = numForceThreadBlocks*forceThreadBlockSize/OpenCLContext::TileSize;
    }

    // Allocate and map the host-visible buffer that receives the interaction count.

    pinnedCountBuffer = new cl::Buffer(context.getContext(), CL_MEM_ALLOC_HOST_PTR, sizeof(int));
    pinnedCountMemory = (int*) context.getQueue().enqueueMapBuffer(*pinnedCountBuffer, CL_TRUE, CL_MAP_READ, 0, sizeof(int));
}

/**
 * Release every object this class allocated.  Members that were never allocated are
 * still NULL, and deleting a null pointer does nothing, so no guards are needed.
 */
OpenCLNonbondedUtilities::~OpenCLNonbondedUtilities() {
    // Exclusion data.

    delete exclusionIndices;
    delete exclusionRowIndices;
    delete exclusionTiles;
    delete exclusions;

    // Neighbor list data.

    delete interactingTiles;
    delete interactingAtoms;
    delete interactionCount;
    delete blockCenter;
    delete blockBoundingBox;
    delete sortedBlocks;
    delete sortedBlockCenter;
    delete sortedBlockBoundingBox;
    delete oldPositions;
    delete rebuildNeighborList;
    delete blockSorter;

    // Host-visible buffer used to read back the interaction count.

    delete pinnedCountBuffer;
}

132
void OpenCLNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup) {
133
    if (groupCutoff.size() > 0) {
134
135
136
137
        if (usesCutoff != useCutoff)
            throw OpenMMException("All Forces must agree on whether to use a cutoff");
        if (usesPeriodic != usePeriodic)
            throw OpenMMException("All Forces must agree on whether to use periodic boundary conditions");
138
139
        if (usesCutoff && groupCutoff.find(forceGroup) != groupCutoff.end() && groupCutoff[forceGroup] != cutoffDistance)
            throw OpenMMException("All Forces in a single force group must use the same cutoff distance");
140
    }
141
142
143
144
    if (usesExclusions)
        requestExclusions(exclusionList);
    useCutoff = usesCutoff;
    usePeriodic = usesPeriodic;
145
146
147
148
149
150
151
152
153
154
    groupCutoff[forceGroup] = cutoffDistance;
    groupFlags |= 1<<forceGroup;
    if (kernel.size() > 0) {
        if (groupKernelSource.find(forceGroup) == groupKernelSource.end())
            groupKernelSource[forceGroup] = "";
        map<string, string> replacements;
        replacements["CUTOFF"] = "CUTOFF_"+context.intToString(forceGroup);
        replacements["CUTOFF_SQUARED"] = "CUTOFF_"+context.intToString(forceGroup)+"_SQUARED";
        groupKernelSource[forceGroup] += context.replaceStrings(kernel, replacements)+"\n";
    }
155
156
157
158
159
160
161
162
163
164
}

// Append a parameter description to the list of parameters.  The list is later passed to
// createInteractionKernel() (see computeInteractions()).
void OpenCLNonbondedUtilities::addParameter(const ParameterInfo& parameter) {
    parameters.push_back(parameter);
}

// Append an argument description to the list of arguments.  The list is later passed to
// createInteractionKernel() (see computeInteractions()).
void OpenCLNonbondedUtilities::addArgument(const ParameterInfo& parameter) {
    arguments.push_back(parameter);
}

165
166
167
168
169
170
171
172
173
174
175
176
177
/**
 * Register a parameter for which the energy derivative is needed.
 *
 * @param param  the name of the parameter
 * @return the name of the variable associated with the parameter: "energyParamDeriv"
 *         followed by the parameter's position in the list of registered parameters
 */
string OpenCLNonbondedUtilities::addEnergyParameterDerivative(const string& param) {
    // Find the position of the parameter, or the end of the list if it is new.

    const int numKnown = (int) energyParameterDerivatives.size();
    int position = 0;
    while (position < numKnown && !(param == energyParameterDerivatives[position]))
        position++;
    if (position == numKnown)
        energyParameterDerivatives.push_back(param);

    // The context is notified regardless of whether the parameter was already known here.

    context.addEnergyParameterDerivative(param);
    return string("energyParamDeriv")+context.intToString(position);
}

178
179
void OpenCLNonbondedUtilities::requestExclusions(const vector<vector<int> >& exclusionList) {
    if (anyExclusions) {
180
        bool sameExclusions = (exclusionList.size() == atomExclusions.size());
181
        for (int i = 0; i < (int) exclusionList.size() && sameExclusions; i++) {
182
183
            if (exclusionList[i].size() != atomExclusions[i].size())
                sameExclusions = false;
184
185
            set<int> expectedExclusions;
            expectedExclusions.insert(atomExclusions[i].begin(), atomExclusions[i].end());
186
            for (int j = 0; j < (int) exclusionList[i].size(); j++)
187
                if (expectedExclusions.find(exclusionList[i][j]) == expectedExclusions.end())
188
189
190
191
192
                    sameExclusions = false;
        }
        if (!sameExclusions)
            throw OpenMMException("All Forces must have identical exceptions");
    }
193
    else {
194
        atomExclusions = exclusionList;
195
196
        anyExclusions = true;
    }
197
198
}

199
static bool compareUshort2(mm_ushort2 a, mm_ushort2 b) {
    // This version is used on devices with SIMD width of 32 or less.  It sorts tiles to improve cache efficiency.
    // Tiles are ordered by y first, with x breaking ties.

    if (a.y != b.y)
        return (a.y < b.y);
    return (a.x < b.x);
}

static bool compareUshort2LargeSIMD(mm_ushort2 a, mm_ushort2 b) {
    // This version is used on devices with SIMD width greater than 32.  It puts diagonal tiles before off-diagonal
    // ones to reduce thread divergence.

    const bool aIsDiagonal = (a.x == a.y);
    const bool bIsDiagonal = (b.x == b.y);
    if (aIsDiagonal != bIsDiagonal)
        return aIsDiagonal; // Exactly one is diagonal, and it comes first.
    if (aIsDiagonal)
        return (a.x < b.x); // Both are diagonal.

    // Neither is diagonal: order by y, then by x.

    if (a.y != b.y)
        return (a.y < b.y);
    return (a.x < b.x);
}

219
void OpenCLNonbondedUtilities::initialize(const System& system) {
220
221
    if (atomExclusions.size() == 0) {
        // No exclusions were specifically requested, so just mark every atom as not interacting with itself.
222

223
        atomExclusions.resize(context.getNumAtoms());
224
        for (int i = 0; i < (int) atomExclusions.size(); i++)
225
226
227
            atomExclusions[i].push_back(i);
    }

228
229
230
    // Create the list of tiles.

    int numAtomBlocks = context.getNumAtomBlocks();
231
    int numContexts = context.getPlatformData().contexts.size();
232
    setAtomBlockRange(context.getContextIndex()/(double) numContexts, (context.getContextIndex()+1)/(double) numContexts);
233

234
    // Build a list of tiles that contain exclusions.
235

236
237
238
239
240
241
242
243
244
    set<pair<int, int> > tilesWithExclusions;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            tilesWithExclusions.insert(make_pair(max(x, y), min(x, y)));
        }
    }
245
246
247
    vector<mm_ushort2> exclusionTilesVec;
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter)
        exclusionTilesVec.push_back(mm_ushort2((unsigned short) iter->first, (unsigned short) iter->second));
peastman's avatar
peastman committed
248
    sort(exclusionTilesVec.begin(), exclusionTilesVec.end(), context.getSIMDWidth() <= 32 ? compareUshort2 : compareUshort2LargeSIMD);
249
250
251
252
253
254
255
256
257
258
259
260
    exclusionTiles = OpenCLArray::create<mm_ushort2>(context, exclusionTilesVec.size(), "exclusionTiles");
    exclusionTiles->upload(exclusionTilesVec);
    map<pair<int, int>, int> exclusionTileMap;
    for (int i = 0; i < (int) exclusionTilesVec.size(); i++) {
        mm_ushort2 tile = exclusionTilesVec[i];
        exclusionTileMap[make_pair(tile.x, tile.y)] = i;
    }
    vector<vector<int> > exclusionBlocksForBlock(numAtomBlocks);
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter) {
        exclusionBlocksForBlock[iter->first].push_back(iter->second);
        if (iter->first != iter->second)
            exclusionBlocksForBlock[iter->second].push_back(iter->first);
261
262
263
    }
    vector<cl_uint> exclusionRowIndicesVec(numAtomBlocks+1, 0);
    vector<cl_uint> exclusionIndicesVec;
264
265
266
    for (int i = 0; i < numAtomBlocks; i++) {
        exclusionIndicesVec.insert(exclusionIndicesVec.end(), exclusionBlocksForBlock[i].begin(), exclusionBlocksForBlock[i].end());
        exclusionRowIndicesVec[i+1] = exclusionIndicesVec.size();
267
    }
268
269
270
    maxExclusions = 0;
    for (int i = 0; i < (int) exclusionBlocksForBlock.size(); i++)
        maxExclusions = (maxExclusions > exclusionBlocksForBlock[i].size() ? maxExclusions : exclusionBlocksForBlock[i].size());
271
272
    exclusionIndices = OpenCLArray::create<cl_uint>(context, exclusionIndicesVec.size(), "exclusionIndices");
    exclusionRowIndices = OpenCLArray::create<cl_uint>(context, exclusionRowIndicesVec.size(), "exclusionRowIndices");
273
274
    exclusionIndices->upload(exclusionIndicesVec);
    exclusionRowIndices->upload(exclusionRowIndicesVec);
275
276
277

    // Record the exclusion data.

278
    exclusions = OpenCLArray::create<cl_uint>(context, tilesWithExclusions.size()*OpenCLContext::TileSize, "exclusions");
279
280
    cl_uint allFlags = (cl_uint) -1;
    vector<cl_uint> exclusionVec(exclusions->getSize(), allFlags);
281
282
283
284
285
286
287
288
289
290
    for (int i = 0; i < exclusions->getSize(); ++i)
        exclusionVec[i] = 0xFFFFFFFF;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x > y) {
291
292
                int index = exclusionTileMap[make_pair(x, y)]*OpenCLContext::TileSize;
                exclusionVec[index+offset1] &= allFlags-(1<<offset2);
293
294
            }
            else {
295
296
                int index = exclusionTileMap[make_pair(y, x)]*OpenCLContext::TileSize;
                exclusionVec[index+offset2] &= allFlags-(1<<offset1);
297
298
299
300
301
            }
        }
    }
    atomExclusions.clear(); // We won't use this again, so free the memory it used
    exclusions->upload(exclusionVec);
302
303
304
305

    // Create data structures for the neighbor list.

    if (useCutoff) {
306
307
308
309
310
311
312
313
314
        // Select a size for the arrays that hold the neighbor list.  We have to make a fairly
        // arbitrary guess, but if this turns out to be too small we'll increase it later.

        int maxTiles = 20*numAtomBlocks;
        if (maxTiles > numTiles)
            maxTiles = numTiles;
        if (maxTiles < 1)
            maxTiles = 1;
        int numAtoms = context.getNumAtoms();
315
        interactingTiles = OpenCLArray::create<cl_int>(context, maxTiles, "interactingTiles");
316
        interactingAtoms = OpenCLArray::create<cl_int>(context, OpenCLContext::TileSize*maxTiles, "interactingAtoms");
317
        interactionCount = OpenCLArray::create<cl_uint>(context, 1, "interactionCount");
318
319
320
321
322
323
324
325
326
        int elementSize = (context.getUseDoublePrecision() ? sizeof(cl_double) : sizeof(cl_float));
        blockCenter = new OpenCLArray(context, numAtomBlocks, 4*elementSize, "blockCenter");
        blockBoundingBox = new OpenCLArray(context, numAtomBlocks, 4*elementSize, "blockBoundingBox");
        sortedBlocks = new OpenCLArray(context, numAtomBlocks, 2*elementSize, "sortedBlocks");
        sortedBlockCenter = new OpenCLArray(context, numAtomBlocks+1, 4*elementSize, "sortedBlockCenter");
        sortedBlockBoundingBox = new OpenCLArray(context, numAtomBlocks+1, 4*elementSize, "sortedBlockBoundingBox");
        oldPositions = new OpenCLArray(context, numAtoms, 4*elementSize, "oldPositions");
        rebuildNeighborList = OpenCLArray::create<int>(context, 1, "rebuildNeighborList");
        blockSorter = new OpenCLSort(context, new BlockSortTrait(context.getUseDoublePrecision()), numAtomBlocks);
327
328
        vector<cl_uint> count(1, 0);
        interactionCount->upload(count);
329
    }
330
331
}

332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
/**
 * Set the five periodic box arguments of a kernel (box size, inverse box size, and the
 * three box vectors), using the types that match the context's precision mode.
 *
 * @param cl      the context that supplies the box values
 * @param kernel  the kernel whose arguments are set
 * @param index   the index of the first of the five consecutive arguments
 */
static void setPeriodicBoxArgs(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (cl.getUseDoublePrecision()) {
        kernel.setArg<mm_double4>(index, cl.getPeriodicBoxSizeDouble());
        kernel.setArg<mm_double4>(index+1, cl.getInvPeriodicBoxSizeDouble());
        kernel.setArg<mm_double4>(index+2, cl.getPeriodicBoxVecXDouble());
        kernel.setArg<mm_double4>(index+3, cl.getPeriodicBoxVecYDouble());
        kernel.setArg<mm_double4>(index+4, cl.getPeriodicBoxVecZDouble());
        return;
    }
    kernel.setArg<mm_float4>(index, cl.getPeriodicBoxSize());
    kernel.setArg<mm_float4>(index+1, cl.getInvPeriodicBoxSize());
    kernel.setArg<mm_float4>(index+2, cl.getPeriodicBoxVecX());
    kernel.setArg<mm_float4>(index+3, cl.getPeriodicBoxVecY());
    kernel.setArg<mm_float4>(index+4, cl.getPeriodicBoxVecZ());
}

349
350
351
352
353
354
355
356
357
358
359
360
double OpenCLNonbondedUtilities::getMaxCutoffDistance() {
    double cutoff = 0.0;
    for (map<int, double>::const_iterator iter = groupCutoff.begin(); iter != groupCutoff.end(); ++iter)
        cutoff = max(cutoff, iter->second);
    return cutoff;
}

void OpenCLNonbondedUtilities::prepareInteractions(int forceGroups) {
    if ((forceGroups&groupFlags) == 0)
        return;
    if (groupKernels.find(forceGroups) == groupKernels.end())
        createKernelsForGroups(forceGroups);
361
362
    if (!useCutoff)
        return;
363
364
    if (numTiles == 0)
        return;
365
    KernelSet& kernels = groupKernels[forceGroups];
366
367
    if (usePeriodic) {
        mm_float4 box = context.getPeriodicBoxSize();
368
        double minAllowedSize = 1.999999*kernels.cutoffDistance;
369
370
371
        if (box.x < minAllowedSize || box.y < minAllowedSize || box.z < minAllowedSize)
            throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
    }
372
373
374

    // Compute the neighbor list.

375
376
    if (lastCutoff != kernels.cutoffDistance)
        forceRebuildNeighborList = true;
377
378
379
380
381
382
383
384
    setPeriodicBoxArgs(context, kernels.findBlockBoundsKernel, 1);
    context.executeKernel(kernels.findBlockBoundsKernel, context.getNumAtoms());
    blockSorter->sort(*sortedBlocks);
    kernels.sortBoxDataKernel.setArg<cl_int>(9, forceRebuildNeighborList);
    context.executeKernel(kernels.sortBoxDataKernel, context.getNumAtoms());
    setPeriodicBoxArgs(context, kernels.findInteractingBlocksKernel, 0);
    context.executeKernel(kernels.findInteractingBlocksKernel, context.getNumAtoms(), interactingBlocksThreadBlockSize);
    forceRebuildNeighborList = false;
385
    lastCutoff = kernels.cutoffDistance;
386
    context.getQueue().enqueueReadBuffer(interactionCount->getDeviceBuffer(), CL_FALSE, 0, sizeof(int), pinnedCountMemory, NULL, &downloadCountEvent); 
387
388
}

389
void OpenCLNonbondedUtilities::computeInteractions(int forceGroups, bool includeForces, bool includeEnergy) {
390
391
392
393
    if ((forceGroups&groupFlags) == 0)
        return;
    KernelSet& kernels = groupKernels[forceGroups];
    if (kernels.hasForces) {
394
395
396
        cl::Kernel& kernel = (includeForces ? (includeEnergy ? kernels.forceEnergyKernel : kernels.forceKernel) : kernels.energyKernel);
        if (*reinterpret_cast<cl_kernel*>(&kernel) == NULL)
            kernel = createInteractionKernel(kernels.source, parameters, arguments, true, true, forceGroups, includeForces, includeEnergy);
397
        if (useCutoff)
398
399
            setPeriodicBoxArgs(context, kernel, 9);
        context.executeKernel(kernel, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
400
    }
401
402
403
404
    if (useCutoff && numTiles > 0) {
        downloadCountEvent.wait();
        updateNeighborListSize();
    }
405
406
}

407
bool OpenCLNonbondedUtilities::updateNeighborListSize() {
408
    if (!useCutoff)
409
        return false;
410
    if (pinnedCountMemory[0] <= (unsigned int) interactingTiles->getSize())
411
        return false;
412
413
414
415

    // The most recent timestep had too many interactions to fit in the arrays.  Make the arrays bigger to prevent
    // this from happening in the future.

416
    int maxTiles = (int) (1.2*pinnedCountMemory[0]);
417
418
419
    int totalTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    if (maxTiles > totalTiles)
        maxTiles = totalTiles;
420
    delete interactingTiles;
421
422
423
    delete interactingAtoms;
    interactingTiles = NULL; // Avoid an error in the destructor if the following allocation fails
    interactingAtoms = NULL;
424
    interactingTiles = OpenCLArray::create<cl_int>(context, maxTiles, "interactingTiles");
425
    interactingAtoms = OpenCLArray::create<cl_int>(context, OpenCLContext::TileSize*maxTiles, "interactingAtoms");
426
    for (map<int, KernelSet>::iterator iter = groupKernels.begin(); iter != groupKernels.end(); ++iter) {
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
        KernelSet& kernels = iter->second;
        if (*reinterpret_cast<cl_kernel*>(&kernels.forceKernel) != NULL) {
            kernels.forceKernel.setArg<cl::Buffer>(7, interactingTiles->getDeviceBuffer());
            kernels.forceKernel.setArg<cl_uint>(14, maxTiles);
            kernels.forceKernel.setArg<cl::Buffer>(17, interactingAtoms->getDeviceBuffer());
        }
        if (*reinterpret_cast<cl_kernel*>(&kernels.energyKernel) != NULL) {
            kernels.energyKernel.setArg<cl::Buffer>(7, interactingTiles->getDeviceBuffer());
            kernels.energyKernel.setArg<cl_uint>(14, maxTiles);
            kernels.energyKernel.setArg<cl::Buffer>(17, interactingAtoms->getDeviceBuffer());
        }
        if (*reinterpret_cast<cl_kernel*>(&kernels.forceEnergyKernel) != NULL) {
            kernels.forceEnergyKernel.setArg<cl::Buffer>(7, interactingTiles->getDeviceBuffer());
            kernels.forceEnergyKernel.setArg<cl_uint>(14, maxTiles);
            kernels.forceEnergyKernel.setArg<cl::Buffer>(17, interactingAtoms->getDeviceBuffer());
        }
        kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
        kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactingAtoms->getDeviceBuffer());
        kernels.findInteractingBlocksKernel.setArg<cl_uint>(9, maxTiles);
446
447
    }
    forceRebuildNeighborList = true;
448
    context.setForcesValid(false);
449
    return true;
450
451
}

452
453
454
455
456
457
458
459
460
461
462
// Set whether the cutoff is padded.  createKernelsForGroups() reads this flag and, when it
// is set, adds 10% of the cutoff as padding; this function does not modify existing kernels.
void OpenCLNonbondedUtilities::setUsePadding(bool padding) {
    usePadding = padding;
}

void OpenCLNonbondedUtilities::setAtomBlockRange(double startFraction, double endFraction) {
    int numAtomBlocks = context.getNumAtomBlocks();
    startBlockIndex = (int) (startFraction*numAtomBlocks);
    numBlocks = (int) (endFraction*numAtomBlocks)-startBlockIndex;
    int totalTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    startTileIndex = (int) (startFraction*totalTiles);;
    numTiles = (int) (endFraction*totalTiles)-startTileIndex;
463
    if (useCutoff) {
464
        // We are using a cutoff, and the kernels have already been created.
465

466
        for (map<int, KernelSet>::iterator iter = groupKernels.begin(); iter != groupKernels.end(); ++iter) {
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
            KernelSet& kernels = iter->second;
            if (*reinterpret_cast<cl_kernel*>(&kernels.forceKernel) != NULL) {
                kernels.forceKernel.setArg<cl_uint>(5, startTileIndex);
                kernels.forceKernel.setArg<cl_uint>(6, numTiles);
            }
            if (*reinterpret_cast<cl_kernel*>(&kernels.energyKernel) != NULL) {
                kernels.energyKernel.setArg<cl_uint>(5, startTileIndex);
                kernels.energyKernel.setArg<cl_uint>(6, numTiles);
            }
            if (*reinterpret_cast<cl_kernel*>(&kernels.forceEnergyKernel) != NULL) {
                kernels.forceEnergyKernel.setArg<cl_uint>(5, startTileIndex);
                kernels.forceEnergyKernel.setArg<cl_uint>(6, numTiles);
            }
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(10, startBlockIndex);
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(11, numBlocks);
482
483
        }
        forceRebuildNeighborList = true;
484
485
486
    }
}

487
488
489
490
491
492
493
494
495
496
497
498
void OpenCLNonbondedUtilities::createKernelsForGroups(int groups) {
    KernelSet kernels;
    double cutoff = 0.0;
    string source;
    for (int i = 0; i < 32; i++) {
        if ((groups&(1<<i)) != 0) {
            cutoff = max(cutoff, groupCutoff[i]);
            source += groupKernelSource[i];
        }
    }
    kernels.hasForces = (source.size() > 0);
    kernels.cutoffDistance = cutoff;
499
    kernels.source = source;
500
    if (useCutoff) {
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
        double padding = (usePadding ? 0.1*cutoff : 0.0);
        double paddedCutoff = cutoff+padding;
        map<string, string> defines;
        defines["TILE_SIZE"] = context.intToString(OpenCLContext::TileSize);
        defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
        defines["PADDING"] = context.doubleToString(padding);
        defines["PADDED_CUTOFF"] = context.doubleToString(paddedCutoff);
        defines["PADDED_CUTOFF_SQUARED"] = context.doubleToString(paddedCutoff*paddedCutoff);
        defines["NUM_TILES_WITH_EXCLUSIONS"] = context.intToString(exclusionTiles->getSize());
        defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
        defines["SIMD_WIDTH"] = context.intToString(context.getSIMDWidth());
        if (usePeriodic)
            defines["USE_PERIODIC"] = "1";
        defines["MAX_EXCLUSIONS"] = context.intToString(maxExclusions);
        defines["BUFFER_GROUPS"] = (deviceIsCpu ? "4" : "2");
        string file = (deviceIsCpu ? OpenCLKernelSources::findInteractingBlocks_cpu : OpenCLKernelSources::findInteractingBlocks);
        int groupSize = (deviceIsCpu || context.getSIMDWidth() < 32 ? 32 : 256);
        while (true) {
            defines["GROUP_SIZE"] = context.intToString(groupSize);
            cl::Program interactingBlocksProgram = context.createProgram(file, defines);
            kernels.findBlockBoundsKernel = cl::Kernel(interactingBlocksProgram, "findBlockBounds");
            kernels.findBlockBoundsKernel.setArg<cl_int>(0, context.getNumAtoms());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(6, context.getPosq().getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(7, blockCenter->getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(8, blockBoundingBox->getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(9, rebuildNeighborList->getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(10, sortedBlocks->getDeviceBuffer());
            kernels.sortBoxDataKernel = cl::Kernel(interactingBlocksProgram, "sortBoxData");
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(0, sortedBlocks->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(1, blockCenter->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(2, blockBoundingBox->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(3, sortedBlockCenter->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(4, sortedBlockBoundingBox->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(5, context.getPosq().getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(6, oldPositions->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(7, interactionCount->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(8, rebuildNeighborList->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl_int>(9, true);
            kernels.findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions");
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(5, interactionCount->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactingAtoms->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(8, context.getPosq().getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(9, interactingTiles->getSize());
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(10, startBlockIndex);
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(11, numBlocks);
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(12, sortedBlocks->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(13, sortedBlockCenter->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(14, sortedBlockBoundingBox->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(15, exclusionIndices->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(16, exclusionRowIndices->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(17, oldPositions->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(18, rebuildNeighborList->getDeviceBuffer());
            if (kernels.findInteractingBlocksKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice()) < groupSize) {
                // The device can't handle this block size, so reduce it.
556

557
558
559
560
561
562
563
564
565
566
567
568
                groupSize -= 32;
                if (groupSize < 32)
                    throw OpenMMException("Failed to create findInteractingBlocks kernel");
                continue;
            }
            break;
        }
        interactingBlocksThreadBlockSize = (deviceIsCpu ? 1 : groupSize);
    }
    groupKernels[groups] = kernels;
}

569
/**
 * Generate, compile and configure a "computeNonbonded" kernel.
 *
 * The kernel template (nonbonded or nonbonded_cpu, depending on deviceIsCpu) has a set
 * of placeholders replaced by generated OpenCL code, is compiled with preprocessor
 * definitions describing the current configuration, and then has its arguments set.
 *
 * @param source           code substituted for the COMPUTE_INTERACTION placeholder
 * @param params           parameters passed to the kernel as "global_<name>" buffers; the
 *                         generated code loads them for each atom and stages them in localData
 * @param arguments        additional kernel arguments, passed as 2D images or as buffers
 * @param useExclusions    whether to define USE_EXCLUSIONS
 * @param isSymmetric      whether to define USE_SYMMETRIC
 * @param groups           bit flags; bit i causes CUTOFF_i and CUTOFF_i_SQUARED to be defined
 *                         from groupCutoff[i]
 * @param includeForces    whether to define INCLUDE_FORCES
 * @param includeEnergy    whether to define INCLUDE_ENERGY
 * @return the kernel, with every argument set except the five periodic box slots that
 *         are skipped when a cutoff is in use
 */
cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& source, const vector<ParameterInfo>& params, const vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups, bool includeForces, bool includeEnergy) {
    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = source;
    const string suffixes[] = {"x", "y", "z", "w"};

    // Fields of the local data structure.  Multi-component parameters are split into
    // one field per component.

    stringstream localData;
    int localDataSize = 0;
    for (int i = 0; i < (int) params.size(); i++) {
        if (params[i].getNumComponents() == 1)
            localData<<params[i].getType()<<" "<<params[i].getName()<<";\n";
        else {
            for (int j = 0; j < params[i].getNumComponents(); ++j)
                localData<<params[i].getComponentType()<<" "<<params[i].getName()<<"_"<<suffixes[j]<<";\n";
        }
        localDataSize += params[i].getSize();
    }
    replacements["ATOM_PARAMETER_DATA"] = localData.str();

    // Extra entries for the kernel's argument list.

    stringstream args;
    for (int i = 0; i < (int) params.size(); i++) {
        args << ", __global const ";
        args << params[i].getType();
        args << "* restrict global_";
        args << params[i].getName();
    }
    for (int i = 0; i < (int) arguments.size(); i++) {
        if (arguments[i].getMemory().getInfo<CL_MEM_TYPE>() == CL_MEM_OBJECT_IMAGE2D) {
            args << ", __read_only image2d_t ";
            args << arguments[i].getName();
        }
        else {
            // Buffers created with CL_MEM_READ_ONLY are declared __constant; all others __global const.
            if ((arguments[i].getMemory().getInfo<CL_MEM_FLAGS>() & CL_MEM_READ_ONLY) == 0)
                args << ", __global const ";
            else
                args << ", __constant ";
            args << arguments[i].getType();
            args << "* restrict ";
            args << arguments[i].getName();
        }
    }
    if (energyParameterDerivatives.size() > 0)
        args << ", __global mixed* energyParamDerivs";
    replacements["PARAMETER_ARGUMENTS"] = args.str();

    // Code that copies atom 1's parameters into the local data.

    stringstream loadLocal1;
    for (int i = 0; i < (int) params.size(); i++) {
        if (params[i].getNumComponents() == 1) {
            loadLocal1<<"localData[localAtomIndex]."<<params[i].getName()<<" = "<<params[i].getName()<<"1;\n";
        }
        else {
            for (int j = 0; j < params[i].getNumComponents(); ++j)
                loadLocal1<<"localData[localAtomIndex]."<<params[i].getName()<<"_"<<suffixes[j]<<" = "<<params[i].getName()<<"1."<<suffixes[j]<<";\n";
        }
    }
    replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();

    // Code that fills the local data from the global parameter buffers.

    stringstream loadLocal2;
    for (int i = 0; i < (int) params.size(); i++) {
        if (params[i].getNumComponents() == 1) {
            loadLocal2<<"localData[localAtomIndex]."<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
        }
        else {
            loadLocal2<<params[i].getType()<<" temp_"<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
            for (int j = 0; j < params[i].getNumComponents(); ++j)
                loadLocal2<<"localData[localAtomIndex]."<<params[i].getName()<<"_"<<suffixes[j]<<" = temp_"<<params[i].getName()<<"."<<suffixes[j]<<";\n";
        }
    }
    replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();

    // Code that loads atom 1's parameters from the global buffers.

    stringstream load1;
    for (int i = 0; i < (int) params.size(); i++) {
        load1 << params[i].getType();
        load1 << " ";
        load1 << params[i].getName();
        load1 << "1 = global_";
        load1 << params[i].getName();
        load1 << "[atom1];\n";
    }
    replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();

    // Code that reassembles atom 2's parameters from the local data.

    stringstream load2j;
    for (int i = 0; i < (int) params.size(); i++) {
        if (params[i].getNumComponents() == 1) {
            load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = localData[atom2]."<<params[i].getName()<<";\n";
        }
        else {
            load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = ("<<params[i].getType()<<") (";
            for (int j = 0; j < params[i].getNumComponents(); ++j) {
                if (j > 0)
                    load2j<<", ";
                load2j<<"localData[atom2]."<<params[i].getName()<<"_"<<suffixes[j];
            }
            load2j<<");\n";
        }
    }
    replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str();

    // Code for accumulating and saving energy parameter derivatives.

    stringstream initDerivs;
    for (int i = 0; i < (int) energyParameterDerivatives.size(); i++)
        initDerivs<<"mixed energyParamDeriv"<<i<<" = 0;\n";
    replacements["INIT_DERIVATIVES"] = initDerivs.str();
    stringstream saveDerivs;
    const vector<string>& allParamDerivNames = context.getEnergyParamDerivNames();
    int numDerivs = allParamDerivNames.size();
    for (int i = 0; i < (int) energyParameterDerivatives.size(); i++)
        for (int index = 0; index < numDerivs; index++)
            if (allParamDerivNames[index] == energyParameterDerivatives[i]) {
                // The output buffer holds numDerivs entries per work item, ordered like the
                // context's list of names, so the slot is the position found in that list
                // (index), not the position in this object's own list (i).
                saveDerivs<<"energyParamDerivs[get_global_id(0)*"<<numDerivs<<"+"<<index<<"] += energyParamDeriv"<<i<<";\n";
            }
    replacements["SAVE_DERIVATIVES"] = saveDerivs.str();

    // Preprocessor definitions describing the configuration.

    map<string, string> defines;
    if (useCutoff)
        defines["USE_CUTOFF"] = "1";
    if (usePeriodic)
        defines["USE_PERIODIC"] = "1";
    if (useExclusions)
        defines["USE_EXCLUSIONS"] = "1";
    if (isSymmetric)
        defines["USE_SYMMETRIC"] = "1";
    if (useCutoff && context.getSIMDWidth() < 32)
        defines["PRUNE_BY_CUTOFF"] = "1";
    if (includeForces)
        defines["INCLUDE_FORCES"] = "1";
    if (includeEnergy)
        defines["INCLUDE_ENERGY"] = "1";
    defines["FORCE_WORK_GROUP_SIZE"] = context.intToString(forceThreadBlockSize);
    double maxCutoff = 0.0;
    for (int i = 0; i < 32; i++) {
        // Use an unsigned mask so that testing bit 31 does not shift into the sign bit.
        if ((groups&(1u<<i)) != 0) {
            double cutoff = groupCutoff[i];
            maxCutoff = max(maxCutoff, cutoff);
            defines["CUTOFF_"+context.intToString(i)+"_SQUARED"] = context.doubleToString(cutoff*cutoff);
            defines["CUTOFF_"+context.intToString(i)] = context.doubleToString(cutoff);
        }
    }
    defines["MAX_CUTOFF"] = context.doubleToString(maxCutoff);
    defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
    defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
    defines["TILE_SIZE"] = context.intToString(OpenCLContext::TileSize);
    int numExclusionTiles = exclusionTiles->getSize();
    defines["NUM_TILES_WITH_EXCLUSIONS"] = context.intToString(numExclusionTiles);

    // Divide the exclusion tiles evenly among the contexts listed in the platform data.

    int numContexts = context.getPlatformData().contexts.size();
    int startExclusionIndex = context.getContextIndex()*numExclusionTiles/numContexts;
    int endExclusionIndex = (context.getContextIndex()+1)*numExclusionTiles/numContexts;
    defines["FIRST_EXCLUSION_TILE"] = context.intToString(startExclusionIndex);
    defines["LAST_EXCLUSION_TILE"] = context.intToString(endExclusionIndex);
    if ((localDataSize/4)%2 == 0)
        defines["PARAMETER_SIZE_IS_EVEN"] = "1";
    string file;
    if (deviceIsCpu)
        file = OpenCLKernelSources::nonbonded_cpu;
    else
        file = OpenCLKernelSources::nonbonded;
    cl::Program program = context.createProgram(context.replaceStrings(file, replacements), defines);
    cl::Kernel kernel(program, "computeNonbonded");

    // Set arguments to the Kernel.

    int index = 0;
    if (context.getSupports64BitGlobalAtomics())
        kernel.setArg<cl::Memory>(index++, context.getLongForceBuffer().getDeviceBuffer());
    else
        kernel.setArg<cl::Buffer>(index++, context.getForceBuffers().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusions->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionTiles->getDeviceBuffer());
    kernel.setArg<cl_uint>(index++, startTileIndex);
    kernel.setArg<cl_uint>(index++, numTiles);
    if (useCutoff) {
        kernel.setArg<cl::Buffer>(index++, interactingTiles->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, interactionCount->getDeviceBuffer());
        index += 5; // The periodic box size arguments are set when the kernel is executed.
        kernel.setArg<cl_uint>(index++, interactingTiles->getSize());
        kernel.setArg<cl::Buffer>(index++, blockCenter->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, blockBoundingBox->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, interactingAtoms->getDeviceBuffer());
    }
    for (int i = 0; i < (int) params.size(); i++) {
        kernel.setArg<cl::Memory>(index++, params[i].getMemory());
    }
    for (int i = 0; i < (int) arguments.size(); i++) {
        kernel.setArg<cl::Memory>(index++, arguments[i].getMemory());
    }
    if (energyParameterDerivatives.size() > 0)
        kernel.setArg<cl::Memory>(index++, context.getEnergyParamDerivBuffer().getDeviceBuffer());
    return kernel;
}