OpenCLNonbondedUtilities.cpp 28 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2012 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

27
#include "openmm/OpenMMException.h"
28
29
#include "OpenCLNonbondedUtilities.h"
#include "OpenCLArray.h"
30
#include "OpenCLKernelSources.h"
31
#include "OpenCLExpressionUtilities.h"
32
#include <map>
33
34
#include <set>
#include <utility>
35
36
37
38

using namespace OpenMM;
using namespace std;

39
OpenCLNonbondedUtilities::OpenCLNonbondedUtilities(OpenCLContext& context) : context(context), cutoff(-1.0), useCutoff(false), usePeriodic(false), anyExclusions(false),
        numForceBuffers(0), exclusionIndices(NULL), exclusionRowIndices(NULL), exclusions(NULL), interactingTiles(NULL), interactionFlags(NULL),
        interactionCount(NULL), blockCenter(NULL), blockBoundingBox(NULL), nonbondedForceGroup(0) {
    // usePeriodic is initialized above because createInteractionKernel() reads it unconditionally, even if
    // addInteraction() has never been called.

    // Decide how many thread blocks and force buffers to use.  The choice depends on whether the device is a
    // CPU, whether it has a SIMD width of 32, and whether it supports 64 bit global atomics.

    deviceIsCpu = (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
    forceBufferPerAtomBlock = false;
    if (deviceIsCpu) {
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = 1;
        numForceBuffers = numForceThreadBlocks;
    }
    else if (context.getSIMDWidth() == 32) {
        if (context.getSupports64BitGlobalAtomics()) {
            numForceThreadBlocks = 2*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
            forceThreadBlockSize = 256;

            // Even though using longForceBuffer, still need a single forceBuffer for the reduceForces kernel to convert the long results into float4 which will be used by later kernels.
            numForceBuffers = 1;
        }
        else {
            numForceThreadBlocks = 4*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
            forceThreadBlockSize = 128;
            numForceBuffers = numForceThreadBlocks;
        }
    }
    else {
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = OpenCLContext::ThreadBlockSize;
        if (context.getSupports64BitGlobalAtomics()) {
            // Even though using longForceBuffer, still need a single forceBuffer for the reduceForces kernel to convert the long results into float4 which will be used by later kernels.
            numForceBuffers = 1;
        }
        else {
            numForceBuffers = numForceThreadBlocks;
            if (numForceBuffers >= context.getNumAtomBlocks()) {
                // For small systems, it is more efficient to have one force buffer per block of 32 atoms instead of one per warp.

                forceBufferPerAtomBlock = true;
                numForceBuffers = context.getNumAtomBlocks();
            }
        }
    }
}

OpenCLNonbondedUtilities::~OpenCLNonbondedUtilities() {
    // Release every array this object allocated.  The constructor sets all of these pointers to NULL, and
    // deleting a NULL pointer does nothing, so arrays that were never created need no special handling.
    delete exclusionIndices;
    delete exclusionRowIndices;
    delete exclusions;
    delete interactingTiles;
    delete interactionFlags;
    delete interactionCount;
    delete blockCenter;
    delete blockBoundingBox;
}

102
void OpenCLNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup) {
103
104
105
106
107
108
109
    if (cutoff != -1.0) {
        if (usesCutoff != useCutoff)
            throw OpenMMException("All Forces must agree on whether to use a cutoff");
        if (usesPeriodic != usePeriodic)
            throw OpenMMException("All Forces must agree on whether to use periodic boundary conditions");
        if (cutoffDistance != cutoff)
            throw OpenMMException("All Forces must use the same cutoff distance");
110
111
        if (forceGroup != nonbondedForceGroup)
            throw OpenMMException("All nonbonded forces must be in the same force group");
112
    }
113
114
115
116
117
    if (usesExclusions)
        requestExclusions(exclusionList);
    useCutoff = usesCutoff;
    usePeriodic = usesPeriodic;
    cutoff = cutoffDistance;
Peter Eastman's avatar
Peter Eastman committed
118
119
    if (kernel.size() > 0)
        kernelSource += kernel+"\n";
120
    nonbondedForceGroup = forceGroup;
121
122
123
124
125
126
127
128
129
130
131
132
}

void OpenCLNonbondedUtilities::addParameter(const ParameterInfo& parameter) {
    // Queue a parameter; initialize() passes the accumulated list to createInteractionKernel().
    parameters.insert(parameters.end(), parameter);
}

void OpenCLNonbondedUtilities::addArgument(const ParameterInfo& parameter) {
    // Queue an extra kernel argument; initialize() passes the accumulated list to createInteractionKernel().
    arguments.insert(arguments.end(), parameter);
}

void OpenCLNonbondedUtilities::requestExclusions(const vector<vector<int> >& exclusionList) {
    if (anyExclusions) {
133
        bool sameExclusions = (exclusionList.size() == atomExclusions.size());
134
        for (int i = 0; i < (int) exclusionList.size() && sameExclusions; i++) {
135
136
            if (exclusionList[i].size() != atomExclusions[i].size())
                sameExclusions = false;
137
138
            set<int> expectedExclusions;
            expectedExclusions.insert(atomExclusions[i].begin(), atomExclusions[i].end());
139
            for (int j = 0; j < (int) exclusionList[i].size(); j++)
140
                if (expectedExclusions.find(exclusionList[i][j]) == expectedExclusions.end())
141
142
143
144
145
                    sameExclusions = false;
        }
        if (!sameExclusions)
            throw OpenMMException("All Forces must have identical exceptions");
    }
146
    else {
147
        atomExclusions = exclusionList;
148
149
        anyExclusions = true;
    }
150
151
152
}

void OpenCLNonbondedUtilities::initialize(const System& system) {
153
154
155
156
    if (atomExclusions.size() == 0) {
        // No exclusions were specifically requested, so just mark every atom as not interacting with itself.
        
        atomExclusions.resize(context.getNumAtoms());
157
        for (int i = 0; i < (int) atomExclusions.size(); i++)
158
159
160
            atomExclusions[i].push_back(i);
    }

161
162
163
    // Create the list of tiles.

    int numAtomBlocks = context.getNumAtomBlocks();
164
165
166
167
168
    int totalTiles = numAtomBlocks*(numAtomBlocks+1)/2;
    int numContexts = context.getPlatformData().contexts.size();
    startTileIndex = context.getContextIndex()*totalTiles/numContexts;
    int endTileIndex = (context.getContextIndex()+1)*totalTiles/numContexts;
    numTiles = endTileIndex-startTileIndex;
169
170
171

    // Build a list of indices for the tiles with exclusions.

172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
    set<pair<int, int> > tilesWithExclusions;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            tilesWithExclusions.insert(make_pair(max(x, y), min(x, y)));
        }
    }
    if (context.getPaddedNumAtoms() > context.getNumAtoms()) {
        for (int i = 0; i < numAtomBlocks; ++i)
            tilesWithExclusions.insert(make_pair(numAtomBlocks-1, i));
    }
    vector<cl_uint> exclusionRowIndicesVec(numAtomBlocks+1, 0);
    vector<cl_uint> exclusionIndicesVec;
    int currentRow = 0;
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter) {
        while (iter->first != currentRow)
            exclusionRowIndicesVec[++currentRow] = exclusionIndicesVec.size();
        exclusionIndicesVec.push_back(iter->second);
    }
    exclusionRowIndicesVec[++currentRow] = exclusionIndicesVec.size();
194
195
    exclusionIndices = OpenCLArray::create<cl_uint>(context, exclusionIndicesVec.size(), "exclusionIndices");
    exclusionRowIndices = OpenCLArray::create<cl_uint>(context, exclusionRowIndicesVec.size(), "exclusionRowIndices");
196
197
    exclusionIndices->upload(exclusionIndicesVec);
    exclusionRowIndices->upload(exclusionRowIndicesVec);
198
199
200

    // Record the exclusion data.

201
    exclusions = OpenCLArray::create<cl_uint>(context, tilesWithExclusions.size()*OpenCLContext::TileSize, "exclusions");
202
203
204
205
206
207
208
209
210
211
212
    vector<cl_uint> exclusionVec(exclusions->getSize());
    for (int i = 0; i < exclusions->getSize(); ++i)
        exclusionVec[i] = 0xFFFFFFFF;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x > y) {
213
214
                int index = findExclusionIndex(x, y, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset1] &= 0xFFFFFFFF-(1<<offset2);
215
216
            }
            else {
217
218
                int index = findExclusionIndex(y, x, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset2] &= 0xFFFFFFFF-(1<<offset1);
219
220
221
222
223
224
225
226
227
228
229
230
231
            }
        }
    }

    // Mark all interactions that involve a padding atom as being excluded.

    for (int atom1 = context.getNumAtoms(); atom1 < context.getPaddedNumAtoms(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int atom2 = 0; atom2 < context.getPaddedNumAtoms(); ++atom2) {
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x >= y) {
232
233
                int index = findExclusionIndex(x, y, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset1] &= 0xFFFFFFFF-(1<<offset2);
234
235
            }
            if (y >= x) {
236
237
                int index = findExclusionIndex(y, x, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset2] &= 0xFFFFFFFF-(1<<offset1);
238
239
240
241
242
            }
        }
    }
    atomExclusions.clear(); // We won't use this again, so free the memory it used
    exclusions->upload(exclusionVec);
243
244
245
246

    // Create data structures for the neighbor list.

    if (useCutoff) {
247
248
249
250
251
252
253
        // Select a size for the arrays that hold the neighbor list.  This estimate is intentionally very
        // high, because if it ever is too small, we have to fall back to the N^2 algorithm.

        mm_float4 boxSize = context.getPeriodicBoxSize();
        int maxInteractingTiles = (int) (numTiles*(cutoff/boxSize.x+cutoff/boxSize.y+cutoff/boxSize.z));
        if (maxInteractingTiles > numTiles)
            maxInteractingTiles = numTiles;
254
255
        if (maxInteractingTiles < 1)
            maxInteractingTiles = 1;
256
257
258
        interactingTiles = OpenCLArray::create<mm_ushort2>(context, maxInteractingTiles, "interactingTiles");
        interactionFlags = OpenCLArray::create<cl_uint>(context, context.getSIMDWidth() == 32 ? maxInteractingTiles : (deviceIsCpu ? 2*maxInteractingTiles : 1), "interactionFlags");
        interactionCount = OpenCLArray::create<cl_uint>(context, 1, "interactionCount");
259
260
261
        int elementSize = (context.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4));
        blockCenter = new OpenCLArray(context, numAtomBlocks, elementSize, "blockCenter");
        blockBoundingBox = new OpenCLArray(context, numAtomBlocks, elementSize, "blockBoundingBox");
262
263
        vector<cl_uint> count(1, 0);
        interactionCount->upload(count);
264
    }
265
266
267

    // Create kernels.

Peter Eastman's avatar
Peter Eastman committed
268
269
    if (kernelSource.size() > 0)
        forceKernel = createInteractionKernel(kernelSource, parameters, arguments, true, true);
270
271
    if (useCutoff) {
        map<string, string> defines;
272
        defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
273
274
275
276
        if (forceBufferPerAtomBlock)
            defines["USE_OUTPUT_BUFFER_PER_BLOCK"] = "1";
        if (usePeriodic)
            defines["USE_PERIODIC"] = "1";
277
278
        string file = (deviceIsCpu ? OpenCLKernelSources::findInteractingBlocks_cpu : OpenCLKernelSources::findInteractingBlocks);
        cl::Program interactingBlocksProgram = context.createProgram(file, defines);
279
280
        findBlockBoundsKernel = cl::Kernel(interactingBlocksProgram, "findBlockBounds");
        findBlockBoundsKernel.setArg<cl_int>(0, context.getNumAtoms());
281
282
283
        findBlockBoundsKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
        findBlockBoundsKernel.setArg<cl::Buffer>(4, blockCenter->getDeviceBuffer());
        findBlockBoundsKernel.setArg<cl::Buffer>(5, blockBoundingBox->getDeviceBuffer());
284
        findBlockBoundsKernel.setArg<cl::Buffer>(6, interactionCount->getDeviceBuffer());
285
        findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions");
286
287
288
289
        if (context.getUseDoublePrecision())
            findInteractingBlocksKernel.setArg<cl_double>(0, cutoff*cutoff);
        else
            findInteractingBlocksKernel.setArg<cl_float>(0, (cl_float) (cutoff*cutoff));
290
291
292
293
        findInteractingBlocksKernel.setArg<cl::Buffer>(3, blockCenter->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(4, blockBoundingBox->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(5, interactionCount->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
294
295
296
        findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(8, context.getPosq().getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl_uint>(9, interactingTiles->getSize());
297
298
        findInteractingBlocksKernel.setArg<cl_uint>(10, startTileIndex);
        findInteractingBlocksKernel.setArg<cl_uint>(11, startTileIndex+numTiles);
299
        if (context.getSIMDWidth() == 32 && !deviceIsCpu) {
300
            findInteractionsWithinBlocksKernel = cl::Kernel(interactingBlocksProgram, "findInteractionsWithinBlocks");
301
302
303
304
            if (context.getUseDoublePrecision())
                findInteractionsWithinBlocksKernel.setArg<cl_double>(0, cutoff*cutoff);
            else
                findInteractionsWithinBlocksKernel.setArg<cl_float>(0, (cl_float) (cutoff*cutoff));
305
306
307
308
309
310
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(4, interactingTiles->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(5, blockCenter->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(6, blockBoundingBox->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(8, interactionCount->getDeviceBuffer());
311
            findInteractionsWithinBlocksKernel.setArg(9, 128*sizeof(cl_uint), NULL);
312
            findInteractionsWithinBlocksKernel.setArg<cl_uint>(10, interactingTiles->getSize());
313
        }
314
    }
315
316
}

317
318
319
320
321
322
323
324
325
int OpenCLNonbondedUtilities::findExclusionIndex(int x, int y, const vector<cl_uint>& exclusionIndices, const vector<cl_uint>& exclusionRowIndices) {
    // Scan the columns recorded for tile row x.  When column y is found at position i, return i*TileSize,
    // the offset of that tile's first word in the exclusion data.  Throws if tile (x, y) has no entry.
    int i = exclusionRowIndices[x];
    const int rowEnd = exclusionRowIndices[x+1];
    while (i < rowEnd) {
        if (exclusionIndices[i] == y)
            return i*OpenCLContext::TileSize;
        ++i;
    }
    throw OpenMMException("Internal error: exclusion in unexpected tile");
}

326
327
328
329
330
331
332
333
334
335
336
337
338
339
static void setPeriodicBoxSizeArg(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    // Pass the periodic box size as a float4 or a double4, following the context's precision setting.
    if (!cl.getUseDoublePrecision()) {
        kernel.setArg<mm_float4>(index, cl.getPeriodicBoxSize());
        return;
    }
    kernel.setArg<mm_double4>(index, cl.getPeriodicBoxSizeDouble());
}

static void setInvPeriodicBoxSizeArg(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    // Pass the inverse periodic box size as a float4 or a double4, following the context's precision setting.
    if (!cl.getUseDoublePrecision()) {
        kernel.setArg<mm_float4>(index, cl.getInvPeriodicBoxSize());
        return;
    }
    kernel.setArg<mm_double4>(index, cl.getInvPeriodicBoxSizeDouble());
}

340
341
342
void OpenCLNonbondedUtilities::prepareInteractions() {
    if (!useCutoff)
        return;
343
344
    if (usePeriodic) {
        mm_float4 box = context.getPeriodicBoxSize();
345
        double minAllowedSize = 1.999999*cutoff;
346
347
348
        if (box.x < minAllowedSize || box.y < minAllowedSize || box.z < minAllowedSize)
            throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
    }
349
350
351

    // Compute the neighbor list.

352
353
    setPeriodicBoxSizeArg(context, findBlockBoundsKernel, 1);
    setInvPeriodicBoxSizeArg(context, findBlockBoundsKernel, 2);
354
    context.executeKernel(findBlockBoundsKernel, context.getNumAtoms());
355
356
    setPeriodicBoxSizeArg(context, findInteractingBlocksKernel, 1);
    setInvPeriodicBoxSizeArg(context, findInteractingBlocksKernel, 2);
357
358
    context.executeKernel(findInteractingBlocksKernel, context.getNumAtoms(), deviceIsCpu ? 1 : -1);
    if (context.getSIMDWidth() == 32 && !deviceIsCpu) {
359
360
        setPeriodicBoxSizeArg(context, findInteractionsWithinBlocksKernel, 1);
        setInvPeriodicBoxSizeArg(context, findInteractionsWithinBlocksKernel, 2);
361
        context.executeKernel(findInteractionsWithinBlocksKernel, context.getNumAtoms(), 128);
362
    }
363
364
365
}

void OpenCLNonbondedUtilities::computeInteractions() {
Peter Eastman's avatar
Peter Eastman committed
366
    if (kernelSource.size() > 0) {
367
        if (useCutoff) {
368
369
            setPeriodicBoxSizeArg(context, forceKernel, 10);
            setInvPeriodicBoxSizeArg(context, forceKernel, 11);
370
        }
371
        context.executeKernel(forceKernel, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
372
    }
373
374
}

375
376
377
void OpenCLNonbondedUtilities::updateNeighborListSize() {
    if (!useCutoff)
        return;
378
379
380
    unsigned int* pinnedInteractionCount = (unsigned int*) context.getPinnedBuffer();
    interactionCount->download(pinnedInteractionCount);
    if (pinnedInteractionCount[0] <= (unsigned int) interactingTiles->getSize())
381
382
383
384
385
        return;

    // The most recent timestep had too many interactions to fit in the arrays.  Make the arrays bigger to prevent
    // this from happening in the future.

386
    int newSize = (int) (1.2*pinnedInteractionCount[0]);
387
388
389
390
    int numTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    if (newSize > numTiles)
        newSize = numTiles;
    delete interactingTiles;
391
    interactingTiles = OpenCLArray::create<mm_ushort2>(context, newSize, "interactingTiles");
392
393
    forceKernel.setArg<cl::Buffer>(8, interactingTiles->getDeviceBuffer());
    forceKernel.setArg<cl_uint>(12, newSize);
394
    findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
395
    findInteractingBlocksKernel.setArg<cl_uint>(9, newSize);
Peter Eastman's avatar
Peter Eastman committed
396
    if (context.getSIMDWidth() == 32 || deviceIsCpu) {
397
        delete interactionFlags;
398
        interactionFlags = OpenCLArray::create<cl_uint>(context, deviceIsCpu ? 2*newSize : newSize, "interactionFlags");
399
        forceKernel.setArg<cl::Buffer>(13, interactionFlags->getDeviceBuffer());
400
        findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
401
402
403
404
405
		if (!deviceIsCpu) {
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(4, interactingTiles->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl_uint>(10, newSize);
		}
406
407
408
    }
}

409
410
411
void OpenCLNonbondedUtilities::setTileRange(int startTileIndex, int numTiles) {
    this->startTileIndex = startTileIndex;
    this->numTiles = numTiles;
Peter Eastman's avatar
Peter Eastman committed
412
    if (kernelSource.size() == 0)
413
        return; // There are no nonbonded interactions in the System.
414
415
    forceKernel.setArg<cl_uint>(6, startTileIndex);
    forceKernel.setArg<cl_uint>(7, startTileIndex+numTiles);
416
417
418
419
    if (useCutoff) {
        findInteractingBlocksKernel.setArg<cl_uint>(10, startTileIndex);
        findInteractingBlocksKernel.setArg<cl_uint>(11, startTileIndex+numTiles);
    }
420
    else
421
        forceKernel.setArg<cl_uint>(8, numTiles);
422
423
}

424
cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& source, const vector<ParameterInfo>& params, const vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric) const {
    // Builds the nonbonded force kernel.  The per-interaction code in `source` and code generated from the
    // per-atom parameters are substituted into the device-specific nonbonded template, and the result is
    // compiled with defines describing the current settings.  Every kernel argument is bound here except the
    // periodic box arguments, which are set when the kernel is executed.
    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = source;
    const string suffixes[] = {"x", "y", "z", "w"};

    // Generate all the per-parameter code fragments in a single pass.  Each stream receives its text in
    // parameter order.  Multi-component parameters are split into one local field per component
    // (name_x, name_y, ...).
    stringstream localData;
    stringstream args;
    stringstream loadLocal1;
    stringstream loadLocal2;
    stringstream load1;
    stringstream load2j;
    int localDataSize = 0;
    for (int i = 0; i < (int) params.size(); i++) {
        const ParameterInfo& param = params[i];
        const string name = param.getName();
        const string type = param.getType();
        const int components = param.getNumComponents();
        if (components == 1) {
            localData<<type<<" "<<name<<";\n";
            loadLocal1<<"localData[localAtomIndex]."<<name<<" = "<<name<<"1;\n";
            loadLocal2<<"localData[localAtomIndex]."<<name<<" = global_"<<name<<"[j];\n";
            load2j<<type<<" "<<name<<"2 = localData[atom2]."<<name<<";\n";
        }
        else {
            loadLocal2<<type<<" temp_"<<name<<" = global_"<<name<<"[j];\n";
            load2j<<type<<" "<<name<<"2 = ("<<type<<") (";
            for (int c = 0; c < components; ++c) {
                const string field = name+"_"+suffixes[c];
                localData<<param.getComponentType()<<" "<<field<<";\n";
                loadLocal1<<"localData[localAtomIndex]."<<field<<" = "<<name<<"1."<<suffixes[c]<<";\n";
                loadLocal2<<"localData[localAtomIndex]."<<field<<" = temp_"<<name<<"."<<suffixes[c]<<";\n";
                if (c > 0)
                    load2j<<", ";
                load2j<<"localData[atom2]."<<field;
            }
            load2j<<");\n";
        }
        localDataSize += param.getSize();
        args<<", __global const "<<type<<"* restrict global_"<<name;
        load1<<type<<" "<<name<<"1 = global_"<<name<<"[atom1];\n";
    }

    // Extra arguments follow the parameters.  2D images are passed as read-only images.  Buffers created
    // with CL_MEM_READ_ONLY go in constant memory; other buffers are passed as const global pointers.
    for (int i = 0; i < (int) arguments.size(); i++) {
        const ParameterInfo& arg = arguments[i];
        if (arg.getMemory().getInfo<CL_MEM_TYPE>() == CL_MEM_OBJECT_IMAGE2D) {
            args<<", __read_only image2d_t "<<arg.getName();
            continue;
        }
        const bool readOnly = ((arg.getMemory().getInfo<CL_MEM_FLAGS>() & CL_MEM_READ_ONLY) != 0);
        args<<(readOnly ? ", __constant " : ", __global const ")<<arg.getType()<<"* restrict "<<arg.getName();
    }
    replacements["ATOM_PARAMETER_DATA"] = localData.str();
    replacements["PARAMETER_ARGUMENTS"] = args.str();
    replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
    replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
    replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
    replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str();

    // Compile-time settings for the kernel template.
    map<string, string> defines;
    if (forceBufferPerAtomBlock)
        defines["USE_OUTPUT_BUFFER_PER_BLOCK"] = "1";
    if (useCutoff)
        defines["USE_CUTOFF"] = "1";
    if (usePeriodic)
        defines["USE_PERIODIC"] = "1";
    if (useExclusions)
        defines["USE_EXCLUSIONS"] = "1";
    if (isSymmetric)
        defines["USE_SYMMETRIC"] = "1";
    defines["FORCE_WORK_GROUP_SIZE"] = context.intToString(forceThreadBlockSize);
    defines["CUTOFF_SQUARED"] = context.doubleToString(cutoff*cutoff);
    defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
    defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
    if ((localDataSize/4)%2 == 0)
        defines["PARAMETER_SIZE_IS_EVEN"] = "1";

    // Choose the template for the device type, then build the program.
    string file = (deviceIsCpu ? OpenCLKernelSources::nonbonded_cpu :
            (context.getSIMDWidth() == 32 ? OpenCLKernelSources::nonbonded_nvidia : OpenCLKernelSources::nonbonded_default));
    cl::Program program = context.createProgram(context.replaceStrings(file, replacements), defines);
    cl::Kernel kernel(program, "computeNonbonded");

    // Bind the kernel arguments in the order the template declares them.
    int index = 0;
    if (context.getSupports64BitGlobalAtomics())
        kernel.setArg<cl::Memory>(index++, context.getLongForceBuffer().getDeviceBuffer());
    else
        kernel.setArg<cl::Buffer>(index++, context.getForceBuffers().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusions->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionIndices->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionRowIndices->getDeviceBuffer());
    kernel.setArg<cl_uint>(index++, startTileIndex);
    kernel.setArg<cl_uint>(index++, startTileIndex+numTiles);
    if (!useCutoff) {
        kernel.setArg<cl_uint>(index++, numTiles);
    }
    else {
        kernel.setArg<cl::Buffer>(index++, interactingTiles->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, interactionCount->getDeviceBuffer());
        index += 2; // The periodic box size arguments are set when the kernel is executed.
        kernel.setArg<cl_uint>(index++, interactingTiles->getSize());
        kernel.setArg<cl::Buffer>(index++, interactionFlags->getDeviceBuffer());
    }
    for (int i = 0; i < (int) params.size(); i++)
        kernel.setArg<cl::Memory>(index++, params[i].getMemory());
    for (int i = 0; i < (int) arguments.size(); i++)
        kernel.setArg<cl::Memory>(index++, arguments[i].getMemory());
    return kernel;
}