OpenCLNonbondedUtilities.cpp 38.4 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2016 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

27
#include "openmm/OpenMMException.h"
28
29
#include "OpenCLNonbondedUtilities.h"
#include "OpenCLArray.h"
30
#include "OpenCLKernelSources.h"
31
#include "OpenCLExpressionUtilities.h"
32
33
#include "OpenCLSort.h"
#include <algorithm>
34
#include <map>
35
36
#include <set>
#include <utility>
37
38
39
40

using namespace OpenMM;
using namespace std;

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
/**
 * Sort trait handed to OpenCLSort for ordering atom blocks.  Each element is a
 * "real2" value whose x component is the sort key; the host-side sizes reported
 * here follow the precision selected at construction.
 */
class OpenCLNonbondedUtilities::BlockSortTrait : public OpenCLSort::SortTrait {
public:
    /**
     * @param useDouble    true if "real" is double precision, false if it is single precision
     */
    BlockSortTrait(bool useDouble) : doublePrecision(useDouble) {
    }

    /** Size in bytes of one element being sorted (a real2). */
    int getDataSize() const {
        if (doublePrecision)
            return sizeof(mm_double2);
        return sizeof(mm_float2);
    }

    /** Size in bytes of one sort key (a real). */
    int getKeySize() const {
        if (doublePrecision)
            return sizeof(cl_double);
        return sizeof(cl_float);
    }

    /** Device-side type name of the elements. */
    const char* getDataType() const {
        return "real2";
    }

    /** Device-side type name of the keys. */
    const char* getKeyType() const {
        return "real";
    }

    /** Device-side expression for the smallest key. */
    const char* getMinKey() const {
        return "-MAXFLOAT";
    }

    /** Device-side expression for the largest key. */
    const char* getMaxKey() const {
        return "MAXFLOAT";
    }

    /** Device-side expression for the largest element value. */
    const char* getMaxValue() const {
        return "(real2) (MAXFLOAT, MAXFLOAT)";
    }

    /** Device-side expression that extracts the key from an element named "value". */
    const char* getSortKey() const {
        return "value.x";
    }

private:
    bool doublePrecision;
};

57
/**
 * Create the nonbonded utilities for a context.  This selects the launch
 * configuration for the force kernels (number of thread blocks, thread block
 * size, and number of force buffers) based on the device type, its SIMD width,
 * and whether 64 bit global atomics are supported.  It also allocates and maps
 * a small host-accessible buffer that is later used to read back the
 * interaction count.
 *
 * @param context    the OpenCLContext these utilities operate on
 */
OpenCLNonbondedUtilities::OpenCLNonbondedUtilities(OpenCLContext& context) : context(context), useCutoff(false), usePeriodic(false), anyExclusions(false), usePadding(true),
        numForceBuffers(0), exclusionIndices(NULL), exclusionRowIndices(NULL), exclusionTiles(NULL), exclusions(NULL), interactingTiles(NULL), interactingAtoms(NULL),
        interactionCount(NULL), blockCenter(NULL), blockBoundingBox(NULL), sortedBlocks(NULL), sortedBlockCenter(NULL), sortedBlockBoundingBox(NULL),
        oldPositions(NULL), rebuildNeighborList(NULL), blockSorter(NULL), pinnedCountBuffer(NULL), pinnedCountMemory(NULL), forceRebuildNeighborList(true), lastCutoff(0.0), groupFlags(0) {
    // Decide how many thread blocks and force buffers to use.

    deviceIsCpu = (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
    if (deviceIsCpu) {
        // On a CPU each "thread block" is a single work item with its own force buffer.
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = 1;
        numForceBuffers = numForceThreadBlocks;
    }
    else if (context.getSIMDWidth() == 32) {
        if (context.getSupports64BitGlobalAtomics()) {
            numForceThreadBlocks = 4*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
            forceThreadBlockSize = 256;
            // Even though using longForceBuffer, still need a single forceBuffer for the reduceForces kernel to convert the long results into float4 which will be used by later kernels.
            numForceBuffers = 1;
        }
        else {
            numForceThreadBlocks = 3*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
            forceThreadBlockSize = 256;
            numForceBuffers = numForceThreadBlocks*forceThreadBlockSize/OpenCLContext::TileSize;
        }
    }
    else {
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = (context.getSIMDWidth() >= 32 ? OpenCLContext::ThreadBlockSize : 32);
        if (context.getSupports64BitGlobalAtomics()) {
            // Even though using longForceBuffer, still need a single forceBuffer for the reduceForces kernel to convert the long results into float4 which will be used by later kernels.
            numForceBuffers = 1;
        }
        else {
            numForceBuffers = numForceThreadBlocks*forceThreadBlockSize/OpenCLContext::TileSize;
        }
    }

    // Allocate a host-accessible buffer holding a single int and map it (blocking), so the
    // interaction count can later be read into pinnedCountMemory.

    pinnedCountBuffer = new cl::Buffer(context.getContext(), CL_MEM_ALLOC_HOST_PTR, sizeof(int));
    pinnedCountMemory = (int*) context.getQueue().enqueueMapBuffer(*pinnedCountBuffer, CL_TRUE, CL_MAP_READ, 0, sizeof(int));
}

/**
 * Release every device array, the block sorter, and the pinned count buffer
 * owned by this object.  Pointers that were never allocated are NULL, and
 * deleting a NULL pointer is a no-op, so no explicit checks are required.
 */
OpenCLNonbondedUtilities::~OpenCLNonbondedUtilities() {
    delete exclusionIndices;
    delete exclusionRowIndices;
    delete exclusionTiles;
    delete exclusions;
    delete interactingTiles;
    delete interactingAtoms;
    delete interactionCount;
    delete blockCenter;
    delete blockBoundingBox;
    delete sortedBlocks;
    delete sortedBlockCenter;
    delete sortedBlockBoundingBox;
    delete oldPositions;
    delete rebuildNeighborList;
    delete blockSorter;
    // NOTE(review): pinnedCountMemory (the mapped pointer) is not explicitly unmapped before
    // the buffer is released -- confirm this is acceptable for the targeted OpenCL runtimes.
    delete pinnedCountBuffer;
}

132
void OpenCLNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup) {
133
    if (groupCutoff.size() > 0) {
134
135
136
137
        if (usesCutoff != useCutoff)
            throw OpenMMException("All Forces must agree on whether to use a cutoff");
        if (usesPeriodic != usePeriodic)
            throw OpenMMException("All Forces must agree on whether to use periodic boundary conditions");
138
139
        if (usesCutoff && groupCutoff.find(forceGroup) != groupCutoff.end() && groupCutoff[forceGroup] != cutoffDistance)
            throw OpenMMException("All Forces in a single force group must use the same cutoff distance");
140
    }
141
142
143
144
    if (usesExclusions)
        requestExclusions(exclusionList);
    useCutoff = usesCutoff;
    usePeriodic = usesPeriodic;
145
146
147
148
149
150
151
152
153
154
    groupCutoff[forceGroup] = cutoffDistance;
    groupFlags |= 1<<forceGroup;
    if (kernel.size() > 0) {
        if (groupKernelSource.find(forceGroup) == groupKernelSource.end())
            groupKernelSource[forceGroup] = "";
        map<string, string> replacements;
        replacements["CUTOFF"] = "CUTOFF_"+context.intToString(forceGroup);
        replacements["CUTOFF_SQUARED"] = "CUTOFF_"+context.intToString(forceGroup)+"_SQUARED";
        groupKernelSource[forceGroup] += context.replaceStrings(kernel, replacements)+"\n";
    }
155
156
157
158
159
160
161
162
163
164
}

/**
 * Append a per-atom parameter description to the list of parameters.
 *
 * @param parameter    the parameter to record (a copy is stored)
 */
void OpenCLNonbondedUtilities::addParameter(const ParameterInfo& parameter) {
    parameters.insert(parameters.end(), parameter);
}

/**
 * Append an extra kernel argument description to the list of arguments.
 *
 * @param parameter    the argument to record (a copy is stored)
 */
void OpenCLNonbondedUtilities::addArgument(const ParameterInfo& parameter) {
    arguments.insert(arguments.end(), parameter);
}

165
166
167
168
169
170
171
172
173
174
175
176
177
/**
 * Register a parameter with respect to which the derivative of the energy is
 * wanted, and report the name of the variable associated with it.
 *
 * @param param    the name of the parameter
 * @return "energyParamDeriv" followed by the index of the parameter in the list
 *         of registered parameters
 */
string OpenCLNonbondedUtilities::addEnergyParameterDerivative(const string& param) {
    // Locate the parameter in the list; a position equal to the list size means it is new.

    int position = (int) (std::find(energyParameterDerivatives.begin(), energyParameterDerivatives.end(), param)-energyParameterDerivatives.begin());
    if (position == (int) energyParameterDerivatives.size())
        energyParameterDerivatives.push_back(param);

    // The context is notified on every call, whether or not the parameter was already known here.

    context.addEnergyParameterDerivative(param);
    return string("energyParamDeriv")+context.intToString(position);
}

178
179
void OpenCLNonbondedUtilities::requestExclusions(const vector<vector<int> >& exclusionList) {
    if (anyExclusions) {
180
        bool sameExclusions = (exclusionList.size() == atomExclusions.size());
181
        for (int i = 0; i < (int) exclusionList.size() && sameExclusions; i++) {
182
183
            if (exclusionList[i].size() != atomExclusions[i].size())
                sameExclusions = false;
184
185
            set<int> expectedExclusions;
            expectedExclusions.insert(atomExclusions[i].begin(), atomExclusions[i].end());
186
            for (int j = 0; j < (int) exclusionList[i].size(); j++)
187
                if (expectedExclusions.find(exclusionList[i][j]) == expectedExclusions.end())
188
189
190
191
192
                    sameExclusions = false;
        }
        if (!sameExclusions)
            throw OpenMMException("All Forces must have identical exceptions");
    }
193
    else {
194
        atomExclusions = exclusionList;
195
196
        anyExclusions = true;
    }
197
198
}

199
/**
 * Ordering predicate for tiles: compares by the y component first, then by the
 * x component.
 *
 * @return true if a should be placed before b
 */
static bool compareUshort2(mm_ushort2 a, mm_ushort2 b) {
    // This version is used on devices with SIMD width of 32 or less.  It sorts tiles to improve cache efficiency.

    return ((a.y < b.y) || (a.y == b.y && a.x < b.x));
}

/**
 * Ordering predicate for tiles that places all diagonal tiles (x == y) before
 * all off-diagonal ones.  Diagonal tiles are ordered by x; off-diagonal tiles
 * are ordered by y, then by x.
 *
 * @return true if a should be placed before b
 */
static bool compareUshort2LargeSIMD(mm_ushort2 a, mm_ushort2 b) {
    // This version is used on devices with SIMD width greater than 32.  It puts diagonal tiles before off-diagonal
    // ones to reduce thread divergence.

    if (a.x == a.y) {
        if (b.x == b.y)
            return (a.x < b.x);
        return true;
    }
    if (b.x == b.y)
        return false;
    return ((a.y < b.y) || (a.y == b.y && a.x < b.x));
}

219
void OpenCLNonbondedUtilities::initialize(const System& system) {
220
221
    if (atomExclusions.size() == 0) {
        // No exclusions were specifically requested, so just mark every atom as not interacting with itself.
222

223
        atomExclusions.resize(context.getNumAtoms());
224
        for (int i = 0; i < (int) atomExclusions.size(); i++)
225
226
227
            atomExclusions[i].push_back(i);
    }

228
229
230
    // Create the list of tiles.

    int numAtomBlocks = context.getNumAtomBlocks();
231
    int numContexts = context.getPlatformData().contexts.size();
232
    setAtomBlockRange(context.getContextIndex()/(double) numContexts, (context.getContextIndex()+1)/(double) numContexts);
233

234
    // Build a list of tiles that contain exclusions.
235

236
237
238
239
240
241
242
243
244
    set<pair<int, int> > tilesWithExclusions;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            tilesWithExclusions.insert(make_pair(max(x, y), min(x, y)));
        }
    }
245
246
247
    vector<mm_ushort2> exclusionTilesVec;
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter)
        exclusionTilesVec.push_back(mm_ushort2((unsigned short) iter->first, (unsigned short) iter->second));
peastman's avatar
peastman committed
248
    sort(exclusionTilesVec.begin(), exclusionTilesVec.end(), context.getSIMDWidth() <= 32 ? compareUshort2 : compareUshort2LargeSIMD);
249
250
251
252
253
254
255
256
257
258
259
260
    exclusionTiles = OpenCLArray::create<mm_ushort2>(context, exclusionTilesVec.size(), "exclusionTiles");
    exclusionTiles->upload(exclusionTilesVec);
    map<pair<int, int>, int> exclusionTileMap;
    for (int i = 0; i < (int) exclusionTilesVec.size(); i++) {
        mm_ushort2 tile = exclusionTilesVec[i];
        exclusionTileMap[make_pair(tile.x, tile.y)] = i;
    }
    vector<vector<int> > exclusionBlocksForBlock(numAtomBlocks);
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter) {
        exclusionBlocksForBlock[iter->first].push_back(iter->second);
        if (iter->first != iter->second)
            exclusionBlocksForBlock[iter->second].push_back(iter->first);
261
262
263
    }
    vector<cl_uint> exclusionRowIndicesVec(numAtomBlocks+1, 0);
    vector<cl_uint> exclusionIndicesVec;
264
265
266
    for (int i = 0; i < numAtomBlocks; i++) {
        exclusionIndicesVec.insert(exclusionIndicesVec.end(), exclusionBlocksForBlock[i].begin(), exclusionBlocksForBlock[i].end());
        exclusionRowIndicesVec[i+1] = exclusionIndicesVec.size();
267
    }
268
269
270
    maxExclusions = 0;
    for (int i = 0; i < (int) exclusionBlocksForBlock.size(); i++)
        maxExclusions = (maxExclusions > exclusionBlocksForBlock[i].size() ? maxExclusions : exclusionBlocksForBlock[i].size());
271
272
    exclusionIndices = OpenCLArray::create<cl_uint>(context, exclusionIndicesVec.size(), "exclusionIndices");
    exclusionRowIndices = OpenCLArray::create<cl_uint>(context, exclusionRowIndicesVec.size(), "exclusionRowIndices");
273
274
    exclusionIndices->upload(exclusionIndicesVec);
    exclusionRowIndices->upload(exclusionRowIndicesVec);
275
276
277

    // Record the exclusion data.

278
    exclusions = OpenCLArray::create<cl_uint>(context, tilesWithExclusions.size()*OpenCLContext::TileSize, "exclusions");
279
280
    cl_uint allFlags = (cl_uint) -1;
    vector<cl_uint> exclusionVec(exclusions->getSize(), allFlags);
281
282
283
284
285
286
287
288
289
290
    for (int i = 0; i < exclusions->getSize(); ++i)
        exclusionVec[i] = 0xFFFFFFFF;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x > y) {
291
292
                int index = exclusionTileMap[make_pair(x, y)]*OpenCLContext::TileSize;
                exclusionVec[index+offset1] &= allFlags-(1<<offset2);
293
294
            }
            else {
295
296
                int index = exclusionTileMap[make_pair(y, x)]*OpenCLContext::TileSize;
                exclusionVec[index+offset2] &= allFlags-(1<<offset1);
297
298
299
300
301
            }
        }
    }
    atomExclusions.clear(); // We won't use this again, so free the memory it used
    exclusions->upload(exclusionVec);
302
303
304
305

    // Create data structures for the neighbor list.

    if (useCutoff) {
306
307
308
309
310
311
312
313
314
        // Select a size for the arrays that hold the neighbor list.  We have to make a fairly
        // arbitrary guess, but if this turns out to be too small we'll increase it later.

        int maxTiles = 20*numAtomBlocks;
        if (maxTiles > numTiles)
            maxTiles = numTiles;
        if (maxTiles < 1)
            maxTiles = 1;
        int numAtoms = context.getNumAtoms();
315
        interactingTiles = OpenCLArray::create<cl_int>(context, maxTiles, "interactingTiles");
316
        interactingAtoms = OpenCLArray::create<cl_int>(context, OpenCLContext::TileSize*maxTiles, "interactingAtoms");
317
        interactionCount = OpenCLArray::create<cl_uint>(context, 1, "interactionCount");
318
319
320
321
322
323
324
325
326
        int elementSize = (context.getUseDoublePrecision() ? sizeof(cl_double) : sizeof(cl_float));
        blockCenter = new OpenCLArray(context, numAtomBlocks, 4*elementSize, "blockCenter");
        blockBoundingBox = new OpenCLArray(context, numAtomBlocks, 4*elementSize, "blockBoundingBox");
        sortedBlocks = new OpenCLArray(context, numAtomBlocks, 2*elementSize, "sortedBlocks");
        sortedBlockCenter = new OpenCLArray(context, numAtomBlocks+1, 4*elementSize, "sortedBlockCenter");
        sortedBlockBoundingBox = new OpenCLArray(context, numAtomBlocks+1, 4*elementSize, "sortedBlockBoundingBox");
        oldPositions = new OpenCLArray(context, numAtoms, 4*elementSize, "oldPositions");
        rebuildNeighborList = OpenCLArray::create<int>(context, 1, "rebuildNeighborList");
        blockSorter = new OpenCLSort(context, new BlockSortTrait(context.getUseDoublePrecision()), numAtomBlocks);
327
328
        vector<cl_uint> count(1, 0);
        interactionCount->upload(count);
329
    }
330
331
}

332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
/**
 * Set the five periodic box arguments of a kernel: box size, inverse box size,
 * and the three box vectors, in that order.  The values are passed in double
 * or single precision depending on the precision the context uses.
 *
 * @param cl        the context to read the periodic box from
 * @param kernel    the kernel whose arguments are set
 * @param index     the index of the first of the five consecutive arguments
 */
static void setPeriodicBoxArgs(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (cl.getUseDoublePrecision()) {
        kernel.setArg<mm_double4>(index++, cl.getPeriodicBoxSizeDouble());
        kernel.setArg<mm_double4>(index++, cl.getInvPeriodicBoxSizeDouble());
        kernel.setArg<mm_double4>(index++, cl.getPeriodicBoxVecXDouble());
        kernel.setArg<mm_double4>(index++, cl.getPeriodicBoxVecYDouble());
        kernel.setArg<mm_double4>(index, cl.getPeriodicBoxVecZDouble());
    }
    else {
        kernel.setArg<mm_float4>(index++, cl.getPeriodicBoxSize());
        kernel.setArg<mm_float4>(index++, cl.getInvPeriodicBoxSize());
        kernel.setArg<mm_float4>(index++, cl.getPeriodicBoxVecX());
        kernel.setArg<mm_float4>(index++, cl.getPeriodicBoxVecY());
        kernel.setArg<mm_float4>(index, cl.getPeriodicBoxVecZ());
    }
}

349
350
351
352
353
354
355
356
357
358
359
360
double OpenCLNonbondedUtilities::getMaxCutoffDistance() {
    double cutoff = 0.0;
    for (map<int, double>::const_iterator iter = groupCutoff.begin(); iter != groupCutoff.end(); ++iter)
        cutoff = max(cutoff, iter->second);
    return cutoff;
}

void OpenCLNonbondedUtilities::prepareInteractions(int forceGroups) {
    if ((forceGroups&groupFlags) == 0)
        return;
    if (groupKernels.find(forceGroups) == groupKernels.end())
        createKernelsForGroups(forceGroups);
361
362
    if (!useCutoff)
        return;
363
364
    if (numTiles == 0)
        return;
365
    KernelSet& kernels = groupKernels[forceGroups];
366
367
    if (usePeriodic) {
        mm_float4 box = context.getPeriodicBoxSize();
368
        double minAllowedSize = 1.999999*kernels.cutoffDistance;
369
370
371
        if (box.x < minAllowedSize || box.y < minAllowedSize || box.z < minAllowedSize)
            throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
    }
372
373
374

    // Compute the neighbor list.

375
376
    if (lastCutoff != kernels.cutoffDistance)
        forceRebuildNeighborList = true;
377
378
379
380
381
382
383
384
    setPeriodicBoxArgs(context, kernels.findBlockBoundsKernel, 1);
    context.executeKernel(kernels.findBlockBoundsKernel, context.getNumAtoms());
    blockSorter->sort(*sortedBlocks);
    kernels.sortBoxDataKernel.setArg<cl_int>(9, forceRebuildNeighborList);
    context.executeKernel(kernels.sortBoxDataKernel, context.getNumAtoms());
    setPeriodicBoxArgs(context, kernels.findInteractingBlocksKernel, 0);
    context.executeKernel(kernels.findInteractingBlocksKernel, context.getNumAtoms(), interactingBlocksThreadBlockSize);
    forceRebuildNeighborList = false;
385
    lastCutoff = kernels.cutoffDistance;
386
    context.getQueue().enqueueReadBuffer(interactionCount->getDeviceBuffer(), CL_FALSE, 0, sizeof(int), pinnedCountMemory, NULL, &downloadCountEvent); 
387
388
}

389
void OpenCLNonbondedUtilities::computeInteractions(int forceGroups, bool includeForces, bool includeEnergy) {
390
391
392
393
    if ((forceGroups&groupFlags) == 0)
        return;
    KernelSet& kernels = groupKernels[forceGroups];
    if (kernels.hasForces) {
394
395
396
        cl::Kernel& kernel = (includeForces ? (includeEnergy ? kernels.forceEnergyKernel : kernels.forceKernel) : kernels.energyKernel);
        if (*reinterpret_cast<cl_kernel*>(&kernel) == NULL)
            kernel = createInteractionKernel(kernels.source, parameters, arguments, true, true, forceGroups, includeForces, includeEnergy);
397
        if (useCutoff)
398
399
            setPeriodicBoxArgs(context, kernel, 9);
        context.executeKernel(kernel, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
400
    }
401
402
403
404
    if (useCutoff && numTiles > 0) {
        downloadCountEvent.wait();
        updateNeighborListSize();
    }
405
406
}

407
bool OpenCLNonbondedUtilities::updateNeighborListSize() {
408
    if (!useCutoff)
409
        return false;
410
    if (pinnedCountMemory[0] <= (unsigned int) interactingTiles->getSize())
411
        return false;
412
413
414
415

    // The most recent timestep had too many interactions to fit in the arrays.  Make the arrays bigger to prevent
    // this from happening in the future.

416
    int maxTiles = (int) (1.2*pinnedCountMemory[0]);
417
418
419
    int totalTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    if (maxTiles > totalTiles)
        maxTiles = totalTiles;
420
    delete interactingTiles;
421
422
423
    delete interactingAtoms;
    interactingTiles = NULL; // Avoid an error in the destructor if the following allocation fails
    interactingAtoms = NULL;
424
    interactingTiles = OpenCLArray::create<cl_int>(context, maxTiles, "interactingTiles");
425
    interactingAtoms = OpenCLArray::create<cl_int>(context, OpenCLContext::TileSize*maxTiles, "interactingAtoms");
426
    for (map<int, KernelSet>::iterator iter = groupKernels.begin(); iter != groupKernels.end(); ++iter) {
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
        KernelSet& kernels = iter->second;
        if (*reinterpret_cast<cl_kernel*>(&kernels.forceKernel) != NULL) {
            kernels.forceKernel.setArg<cl::Buffer>(7, interactingTiles->getDeviceBuffer());
            kernels.forceKernel.setArg<cl_uint>(14, maxTiles);
            kernels.forceKernel.setArg<cl::Buffer>(17, interactingAtoms->getDeviceBuffer());
        }
        if (*reinterpret_cast<cl_kernel*>(&kernels.energyKernel) != NULL) {
            kernels.energyKernel.setArg<cl::Buffer>(7, interactingTiles->getDeviceBuffer());
            kernels.energyKernel.setArg<cl_uint>(14, maxTiles);
            kernels.energyKernel.setArg<cl::Buffer>(17, interactingAtoms->getDeviceBuffer());
        }
        if (*reinterpret_cast<cl_kernel*>(&kernels.forceEnergyKernel) != NULL) {
            kernels.forceEnergyKernel.setArg<cl::Buffer>(7, interactingTiles->getDeviceBuffer());
            kernels.forceEnergyKernel.setArg<cl_uint>(14, maxTiles);
            kernels.forceEnergyKernel.setArg<cl::Buffer>(17, interactingAtoms->getDeviceBuffer());
        }
        kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
        kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactingAtoms->getDeviceBuffer());
        kernels.findInteractingBlocksKernel.setArg<cl_uint>(9, maxTiles);
446
447
    }
    forceRebuildNeighborList = true;
448
    context.setForcesValid(false);
449
    return true;
450
451
}

452
453
454
455
456
457
458
459
460
461
462
void OpenCLNonbondedUtilities::setUsePadding(bool padding) {
    usePadding = padding;
}

void OpenCLNonbondedUtilities::setAtomBlockRange(double startFraction, double endFraction) {
    int numAtomBlocks = context.getNumAtomBlocks();
    startBlockIndex = (int) (startFraction*numAtomBlocks);
    numBlocks = (int) (endFraction*numAtomBlocks)-startBlockIndex;
    int totalTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    startTileIndex = (int) (startFraction*totalTiles);;
    numTiles = (int) (endFraction*totalTiles)-startTileIndex;
463
    if (useCutoff) {
464
        // We are using a cutoff, and the kernels have already been created.
465

466
        for (map<int, KernelSet>::iterator iter = groupKernels.begin(); iter != groupKernels.end(); ++iter) {
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
            KernelSet& kernels = iter->second;
            if (*reinterpret_cast<cl_kernel*>(&kernels.forceKernel) != NULL) {
                kernels.forceKernel.setArg<cl_uint>(5, startTileIndex);
                kernels.forceKernel.setArg<cl_uint>(6, numTiles);
            }
            if (*reinterpret_cast<cl_kernel*>(&kernels.energyKernel) != NULL) {
                kernels.energyKernel.setArg<cl_uint>(5, startTileIndex);
                kernels.energyKernel.setArg<cl_uint>(6, numTiles);
            }
            if (*reinterpret_cast<cl_kernel*>(&kernels.forceEnergyKernel) != NULL) {
                kernels.forceEnergyKernel.setArg<cl_uint>(5, startTileIndex);
                kernels.forceEnergyKernel.setArg<cl_uint>(6, numTiles);
            }
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(10, startBlockIndex);
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(11, numBlocks);
482
483
        }
        forceRebuildNeighborList = true;
484
485
486
    }
}

487
488
489
490
491
492
493
494
495
496
497
498
void OpenCLNonbondedUtilities::createKernelsForGroups(int groups) {
    KernelSet kernels;
    double cutoff = 0.0;
    string source;
    for (int i = 0; i < 32; i++) {
        if ((groups&(1<<i)) != 0) {
            cutoff = max(cutoff, groupCutoff[i]);
            source += groupKernelSource[i];
        }
    }
    kernels.hasForces = (source.size() > 0);
    kernels.cutoffDistance = cutoff;
499
    kernels.source = source;
500
    if (useCutoff) {
501
502
503
504
505
506
507
508
509
510
511
512
513
        double padding = (usePadding ? 0.1*cutoff : 0.0);
        double paddedCutoff = cutoff+padding;
        map<string, string> defines;
        defines["TILE_SIZE"] = context.intToString(OpenCLContext::TileSize);
        defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
        defines["PADDING"] = context.doubleToString(padding);
        defines["PADDED_CUTOFF"] = context.doubleToString(paddedCutoff);
        defines["PADDED_CUTOFF_SQUARED"] = context.doubleToString(paddedCutoff*paddedCutoff);
        defines["NUM_TILES_WITH_EXCLUSIONS"] = context.intToString(exclusionTiles->getSize());
        defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
        defines["SIMD_WIDTH"] = context.intToString(context.getSIMDWidth());
        if (usePeriodic)
            defines["USE_PERIODIC"] = "1";
514
515
        if (context.getBoxIsTriclinic())
            defines["TRICLINIC"] = "1";
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
        defines["MAX_EXCLUSIONS"] = context.intToString(maxExclusions);
        defines["BUFFER_GROUPS"] = (deviceIsCpu ? "4" : "2");
        string file = (deviceIsCpu ? OpenCLKernelSources::findInteractingBlocks_cpu : OpenCLKernelSources::findInteractingBlocks);
        int groupSize = (deviceIsCpu || context.getSIMDWidth() < 32 ? 32 : 256);
        while (true) {
            defines["GROUP_SIZE"] = context.intToString(groupSize);
            cl::Program interactingBlocksProgram = context.createProgram(file, defines);
            kernels.findBlockBoundsKernel = cl::Kernel(interactingBlocksProgram, "findBlockBounds");
            kernels.findBlockBoundsKernel.setArg<cl_int>(0, context.getNumAtoms());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(6, context.getPosq().getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(7, blockCenter->getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(8, blockBoundingBox->getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(9, rebuildNeighborList->getDeviceBuffer());
            kernels.findBlockBoundsKernel.setArg<cl::Buffer>(10, sortedBlocks->getDeviceBuffer());
            kernels.sortBoxDataKernel = cl::Kernel(interactingBlocksProgram, "sortBoxData");
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(0, sortedBlocks->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(1, blockCenter->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(2, blockBoundingBox->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(3, sortedBlockCenter->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(4, sortedBlockBoundingBox->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(5, context.getPosq().getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(6, oldPositions->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(7, interactionCount->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl::Buffer>(8, rebuildNeighborList->getDeviceBuffer());
            kernels.sortBoxDataKernel.setArg<cl_int>(9, true);
            kernels.findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions");
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(5, interactionCount->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactingAtoms->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(8, context.getPosq().getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(9, interactingTiles->getSize());
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(10, startBlockIndex);
            kernels.findInteractingBlocksKernel.setArg<cl_uint>(11, numBlocks);
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(12, sortedBlocks->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(13, sortedBlockCenter->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(14, sortedBlockBoundingBox->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(15, exclusionIndices->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(16, exclusionRowIndices->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(17, oldPositions->getDeviceBuffer());
            kernels.findInteractingBlocksKernel.setArg<cl::Buffer>(18, rebuildNeighborList->getDeviceBuffer());
            if (kernels.findInteractingBlocksKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice()) < groupSize) {
                // The device can't handle this block size, so reduce it.
558

559
560
561
562
563
564
565
566
567
568
569
570
                groupSize -= 32;
                if (groupSize < 32)
                    throw OpenMMException("Failed to create findInteractingBlocks kernel");
                continue;
            }
            break;
        }
        interactingBlocksThreadBlockSize = (deviceIsCpu ? 1 : groupSize);
    }
    groupKernels[groups] = kernels;
}

571
/**
 * Build the "computeNonbonded" kernel for a given interaction.
 *
 * The nonbonded kernel template is specialized by substituting generated code
 * fragments (parameter declarations, loads and derivative handling) and by a set
 * of preprocessor definitions, then compiled.  The buffer arguments that do not
 * change between invocations are bound before the kernel is returned.
 *
 * @param source         code substituted for COMPUTE_INTERACTION in the template
 * @param params         per-atom parameters; each becomes a "global_<name>" kernel argument
 * @param arguments      additional kernel arguments (images or buffers)
 * @param useExclusions  whether to define USE_EXCLUSIONS
 * @param isSymmetric    whether to define USE_SYMMETRIC
 * @param groups         bit flags selecting the force groups whose cutoffs are defined
 * @param includeForces  whether to define INCLUDE_FORCES
 * @param includeEnergy  whether to define INCLUDE_ENERGY
 * @return the compiled kernel with its fixed arguments set
 */
cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& source, const vector<ParameterInfo>& params, const vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups, bool includeForces, bool includeEnergy) {
    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = source;
    const string suffixes[] = {"x", "y", "z", "w"};
    const int numParams = (int) params.size();
    const int numArguments = (int) arguments.size();
    const int numEnergyDerivs = (int) energyParameterDerivatives.size();

    // Fields that hold the per-atom parameters.  Multi-component parameters are
    // split into one scalar field per component.

    stringstream localData;
    int localDataSize = 0;
    for (int i = 0; i < numParams; i++) {
        const ParameterInfo& param = params[i];
        if (param.getNumComponents() == 1)
            localData<<param.getType()<<" "<<param.getName()<<";\n";
        else {
            for (int j = 0; j < param.getNumComponents(); ++j)
                localData<<param.getComponentType()<<" "<<param.getName()<<"_"<<suffixes[j]<<";\n";
        }
        localDataSize += param.getSize();
    }
    replacements["ATOM_PARAMETER_DATA"] = localData.str();

    // Extra entries of the kernel's argument list.

    stringstream args;
    for (int i = 0; i < numParams; i++) {
        const ParameterInfo& param = params[i];
        args << ", __global const ";
        args << param.getType();
        args << "* restrict global_";
        args << param.getName();
    }
    for (int i = 0; i < numArguments; i++) {
        const ParameterInfo& argument = arguments[i];
        if (argument.getMemory().getInfo<CL_MEM_TYPE>() == CL_MEM_OBJECT_IMAGE2D) {
            args << ", __read_only image2d_t ";
            args << argument.getName();
        }
        else {
            // Buffers created with CL_MEM_READ_ONLY are passed as __constant,
            // all others as __global const.
            if ((argument.getMemory().getInfo<CL_MEM_FLAGS>() & CL_MEM_READ_ONLY) == 0)
                args << ", __global const ";
            else
                args << ", __constant ";
            args << argument.getType();
            args << "* restrict ";
            args << argument.getName();
        }
    }
    if (numEnergyDerivs > 0)
        args << ", __global mixed* restrict energyParamDerivs";
    replacements["PARAMETER_ARGUMENTS"] = args.str();

    // Copy the first atom's parameters into the local data entry.

    stringstream loadLocal1;
    for (int i = 0; i < numParams; i++) {
        const ParameterInfo& param = params[i];
        if (param.getNumComponents() == 1) {
            loadLocal1<<"localData[localAtomIndex]."<<param.getName()<<" = "<<param.getName()<<"1;\n";
        }
        else {
            for (int j = 0; j < param.getNumComponents(); ++j)
                loadLocal1<<"localData[localAtomIndex]."<<param.getName()<<"_"<<suffixes[j]<<" = "<<param.getName()<<"1."<<suffixes[j]<<";\n";
        }
    }
    replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();

    // Copy parameters from the global arrays into the local data entry.

    stringstream loadLocal2;
    for (int i = 0; i < numParams; i++) {
        const ParameterInfo& param = params[i];
        if (param.getNumComponents() == 1) {
            loadLocal2<<"localData[localAtomIndex]."<<param.getName()<<" = global_"<<param.getName()<<"[j];\n";
        }
        else {
            loadLocal2<<param.getType()<<" temp_"<<param.getName()<<" = global_"<<param.getName()<<"[j];\n";
            for (int j = 0; j < param.getNumComponents(); ++j)
                loadLocal2<<"localData[localAtomIndex]."<<param.getName()<<"_"<<suffixes[j]<<" = temp_"<<param.getName()<<"."<<suffixes[j]<<";\n";
        }
    }
    replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();

    // Load the first atom's parameters from the global arrays.

    stringstream load1;
    for (int i = 0; i < numParams; i++) {
        const ParameterInfo& param = params[i];
        load1 << param.getType();
        load1 << " ";
        load1 << param.getName();
        load1 << "1 = global_";
        load1 << param.getName();
        load1 << "[atom1];\n";
    }
    replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();

    // Load the second atom's parameters from the local data, reassembling
    // multi-component values from their scalar fields.

    stringstream load2j;
    for (int i = 0; i < numParams; i++) {
        const ParameterInfo& param = params[i];
        if (param.getNumComponents() == 1) {
            load2j<<param.getType()<<" "<<param.getName()<<"2 = localData[atom2]."<<param.getName()<<";\n";
        }
        else {
            load2j<<param.getType()<<" "<<param.getName()<<"2 = ("<<param.getType()<<") (";
            for (int j = 0; j < param.getNumComponents(); ++j) {
                if (j > 0)
                    load2j<<", ";
                load2j<<"localData[atom2]."<<param.getName()<<"_"<<suffixes[j];
            }
            load2j<<");\n";
        }
    }
    replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str();

    // Accumulators for the energy parameter derivatives, and the code that adds
    // them to the slot matching each derivative's position in the context's list.

    stringstream initDerivs;
    for (int i = 0; i < numEnergyDerivs; i++)
        initDerivs<<"mixed energyParamDeriv"<<i<<" = 0;\n";
    replacements["INIT_DERIVATIVES"] = initDerivs.str();
    stringstream saveDerivs;
    const vector<string>& allParamDerivNames = context.getEnergyParamDerivNames();
    int numDerivs = allParamDerivNames.size();
    for (int i = 0; i < numEnergyDerivs; i++)
        for (int derivIndex = 0; derivIndex < numDerivs; derivIndex++)
            if (allParamDerivNames[derivIndex] == energyParameterDerivatives[i])
                saveDerivs<<"energyParamDerivs[get_global_id(0)*"<<numDerivs<<"+"<<derivIndex<<"] += energyParamDeriv"<<i<<";\n";
    replacements["SAVE_DERIVATIVES"] = saveDerivs.str();

    // Preprocessor definitions.

    map<string, string> defines;
    if (useCutoff)
        defines["USE_CUTOFF"] = "1";
    if (usePeriodic)
        defines["USE_PERIODIC"] = "1";
    if (useExclusions)
        defines["USE_EXCLUSIONS"] = "1";
    if (isSymmetric)
        defines["USE_SYMMETRIC"] = "1";
    if (useCutoff && context.getSIMDWidth() < 32)
        defines["PRUNE_BY_CUTOFF"] = "1";
    if (includeForces)
        defines["INCLUDE_FORCES"] = "1";
    if (includeEnergy)
        defines["INCLUDE_ENERGY"] = "1";
    defines["FORCE_WORK_GROUP_SIZE"] = context.intToString(forceThreadBlockSize);
    double maxCutoff = 0.0;
    for (int i = 0; i < 32; i++) {
        if ((groups&(1<<i)) != 0) {
            double cutoff = groupCutoff[i];
            maxCutoff = max(maxCutoff, cutoff);
            defines["CUTOFF_"+context.intToString(i)+"_SQUARED"] = context.doubleToString(cutoff*cutoff);
            defines["CUTOFF_"+context.intToString(i)] = context.doubleToString(cutoff);
        }
    }
    defines["MAX_CUTOFF"] = context.doubleToString(maxCutoff);
    defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
    defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
    defines["TILE_SIZE"] = context.intToString(OpenCLContext::TileSize);

    // The exclusion tiles are divided evenly among the contexts of the platform.

    int numExclusionTiles = exclusionTiles->getSize();
    defines["NUM_TILES_WITH_EXCLUSIONS"] = context.intToString(numExclusionTiles);
    int numContexts = context.getPlatformData().contexts.size();
    int startExclusionIndex = context.getContextIndex()*numExclusionTiles/numContexts;
    int endExclusionIndex = (context.getContextIndex()+1)*numExclusionTiles/numContexts;
    defines["FIRST_EXCLUSION_TILE"] = context.intToString(startExclusionIndex);
    defines["LAST_EXCLUSION_TILE"] = context.intToString(endExclusionIndex);

    // NOTE(review): localDataSize/4 is presumably the parameter size in 4-byte words - confirm.
    if ((localDataSize/4)%2 == 0)
        defines["PARAMETER_SIZE_IS_EVEN"] = "1";

    // Compile the specialized source.

    string file = (deviceIsCpu ? OpenCLKernelSources::nonbonded_cpu : OpenCLKernelSources::nonbonded);
    cl::Program program = context.createProgram(context.replaceStrings(file, replacements), defines);
    cl::Kernel kernel(program, "computeNonbonded");

    // Set arguments to the Kernel.

    int index = 0;
    if (context.getSupports64BitGlobalAtomics())
        kernel.setArg<cl::Memory>(index++, context.getLongForceBuffer().getDeviceBuffer());
    else
        kernel.setArg<cl::Buffer>(index++, context.getForceBuffers().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusions->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionTiles->getDeviceBuffer());
    kernel.setArg<cl_uint>(index++, startTileIndex);
    kernel.setArg<cl_uint>(index++, numTiles);
    if (useCutoff) {
        kernel.setArg<cl::Buffer>(index++, interactingTiles->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, interactionCount->getDeviceBuffer());
        index += 5; // The periodic box size arguments are set when the kernel is executed.
        kernel.setArg<cl_uint>(index++, interactingTiles->getSize());
        kernel.setArg<cl::Buffer>(index++, blockCenter->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, blockBoundingBox->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, interactingAtoms->getDeviceBuffer());
    }
    for (int i = 0; i < numParams; i++)
        kernel.setArg<cl::Memory>(index++, params[i].getMemory());
    for (int i = 0; i < numArguments; i++)
        kernel.setArg<cl::Memory>(index++, arguments[i].getMemory());
    if (numEnergyDerivs > 0)
        kernel.setArg<cl::Memory>(index++, context.getEnergyParamDerivBuffer().getDeviceBuffer());
    return kernel;
}