OpenCLNonbondedUtilities.cpp 28 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2012 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

27
#include "openmm/OpenMMException.h"
28
29
#include "OpenCLNonbondedUtilities.h"
#include "OpenCLArray.h"
30
#include "OpenCLKernelSources.h"
31
#include "OpenCLExpressionUtilities.h"
32
#include <map>
33
34
#include <set>
#include <utility>
35
36
37
38

using namespace OpenMM;
using namespace std;

39
/**
 * Construct the utilities object for a context and choose the launch configuration
 * (number of thread blocks, thread block size, number of force buffers) for the force kernel.
 *
 * All device array pointers start out NULL; this constructor creates no arrays.
 */
OpenCLNonbondedUtilities::OpenCLNonbondedUtilities(OpenCLContext& context) : context(context), cutoff(-1.0), useCutoff(false), anyExclusions(false),
        numForceBuffers(0), exclusionIndices(NULL), exclusionRowIndices(NULL), exclusions(NULL), interactingTiles(NULL), interactionFlags(NULL),
        interactionCount(NULL), blockCenter(NULL), blockBoundingBox(NULL), nonbondedForceGroup(0) {
    forceBufferPerAtomBlock = false;
    deviceIsCpu = (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);

    // CPU devices: thread blocks of size one, each with its own force buffer.

    if (deviceIsCpu) {
        forceThreadBlockSize = 1;
        numForceThreadBlocks = context.getNumThreadBlocks();
        numForceBuffers = numForceThreadBlocks;
        return;
    }

    // Devices with a SIMD width of 32: size the launch from the number of compute units.

    if (context.getSIMDWidth() == 32) {
        if (!context.getSupports64BitGlobalAtomics()) {
            forceThreadBlockSize = 128;
            numForceThreadBlocks = 4*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
            numForceBuffers = numForceThreadBlocks;
        }
        else {
            forceThreadBlockSize = 256;
            numForceThreadBlocks = 2*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
            // Even though using longForceBuffer, still need a single forceBuffer for the reduceForces kernel to convert the long results into float4 which will be used by later kernels.
            numForceBuffers = 1;
        }
        return;
    }

    // Every other device: use the context's default launch configuration.

    forceThreadBlockSize = OpenCLContext::ThreadBlockSize;
    numForceThreadBlocks = context.getNumThreadBlocks();
    if (context.getSupports64BitGlobalAtomics()) {
        // Even though using longForceBuffer, still need a single forceBuffer for the reduceForces kernel to convert the long results into float4 which will be used by later kernels.
        numForceBuffers = 1;
        return;
    }
    numForceBuffers = numForceThreadBlocks;
    if (numForceBuffers >= context.getNumAtomBlocks()) {
        // For small systems, it is more efficient to have one force buffer per block of 32 atoms instead of one per warp.
        numForceBuffers = context.getNumAtomBlocks();
        forceBufferPerAtomBlock = true;
    }
}

/**
 * Release every device array owned by this object.
 */
OpenCLNonbondedUtilities::~OpenCLNonbondedUtilities() {
    // Applying delete to a null pointer does nothing, so arrays that were never
    // created (pointers still NULL from the constructor) need no explicit check.
    delete exclusionIndices;
    delete exclusionRowIndices;
    delete exclusions;
    delete interactingTiles;
    delete interactionFlags;
    delete interactionCount;
    delete blockCenter;
    delete blockBoundingBox;
}

102
void OpenCLNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup) {
103
104
105
106
107
108
109
    if (cutoff != -1.0) {
        if (usesCutoff != useCutoff)
            throw OpenMMException("All Forces must agree on whether to use a cutoff");
        if (usesPeriodic != usePeriodic)
            throw OpenMMException("All Forces must agree on whether to use periodic boundary conditions");
        if (cutoffDistance != cutoff)
            throw OpenMMException("All Forces must use the same cutoff distance");
110
111
        if (forceGroup != nonbondedForceGroup)
            throw OpenMMException("All nonbonded forces must be in the same force group");
112
    }
113
114
115
116
117
118
    if (usesExclusions)
        requestExclusions(exclusionList);
    useCutoff = usesCutoff;
    usePeriodic = usesPeriodic;
    cutoff = cutoffDistance;
    kernelSource += kernel+"\n";
119
    nonbondedForceGroup = forceGroup;
120
121
122
123
124
125
126
127
128
129
130
131
}

/**
 * Append a parameter to the list passed as "params" when the force kernel is created in initialize().
 *
 * @param parameter    the description of the parameter to add
 */
void OpenCLNonbondedUtilities::addParameter(const ParameterInfo& parameter) {
    parameters.push_back(parameter);
}

/**
 * Append an argument to the list passed as "arguments" when the force kernel is created in initialize().
 *
 * @param parameter    the description of the argument to add
 */
void OpenCLNonbondedUtilities::addArgument(const ParameterInfo& parameter) {
    arguments.push_back(parameter);
}

void OpenCLNonbondedUtilities::requestExclusions(const vector<vector<int> >& exclusionList) {
    if (anyExclusions) {
132
        bool sameExclusions = (exclusionList.size() == atomExclusions.size());
133
        for (int i = 0; i < (int) exclusionList.size() && sameExclusions; i++) {
134
135
            if (exclusionList[i].size() != atomExclusions[i].size())
                sameExclusions = false;
136
137
            set<int> expectedExclusions;
            expectedExclusions.insert(atomExclusions[i].begin(), atomExclusions[i].end());
138
            for (int j = 0; j < (int) exclusionList[i].size(); j++)
139
                if (expectedExclusions.find(exclusionList[i][j]) == expectedExclusions.end())
140
141
142
143
144
                    sameExclusions = false;
        }
        if (!sameExclusions)
            throw OpenMMException("All Forces must have identical exceptions");
    }
145
    else {
146
        atomExclusions = exclusionList;
147
148
        anyExclusions = true;
    }
149
150
151
152
153
154
}

void OpenCLNonbondedUtilities::initialize(const System& system) {
    if (cutoff == -1.0)
        return; // There are no nonbonded interactions in the System.
    
155
156
157
158
    if (atomExclusions.size() == 0) {
        // No exclusions were specifically requested, so just mark every atom as not interacting with itself.
        
        atomExclusions.resize(context.getNumAtoms());
159
        for (int i = 0; i < (int) atomExclusions.size(); i++)
160
161
162
            atomExclusions[i].push_back(i);
    }

163
164
165
    // Create the list of tiles.

    int numAtomBlocks = context.getNumAtomBlocks();
166
167
168
169
170
    int totalTiles = numAtomBlocks*(numAtomBlocks+1)/2;
    int numContexts = context.getPlatformData().contexts.size();
    startTileIndex = context.getContextIndex()*totalTiles/numContexts;
    int endTileIndex = (context.getContextIndex()+1)*totalTiles/numContexts;
    numTiles = endTileIndex-startTileIndex;
171
172
173

    // Build a list of indices for the tiles with exclusions.

174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
    set<pair<int, int> > tilesWithExclusions;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            tilesWithExclusions.insert(make_pair(max(x, y), min(x, y)));
        }
    }
    if (context.getPaddedNumAtoms() > context.getNumAtoms()) {
        for (int i = 0; i < numAtomBlocks; ++i)
            tilesWithExclusions.insert(make_pair(numAtomBlocks-1, i));
    }
    vector<cl_uint> exclusionRowIndicesVec(numAtomBlocks+1, 0);
    vector<cl_uint> exclusionIndicesVec;
    int currentRow = 0;
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter) {
        while (iter->first != currentRow)
            exclusionRowIndicesVec[++currentRow] = exclusionIndicesVec.size();
        exclusionIndicesVec.push_back(iter->second);
    }
    exclusionRowIndicesVec[++currentRow] = exclusionIndicesVec.size();
196
197
    exclusionIndices = OpenCLArray::create<cl_uint>(context, exclusionIndicesVec.size(), "exclusionIndices");
    exclusionRowIndices = OpenCLArray::create<cl_uint>(context, exclusionRowIndicesVec.size(), "exclusionRowIndices");
198
199
    exclusionIndices->upload(exclusionIndicesVec);
    exclusionRowIndices->upload(exclusionRowIndicesVec);
200
201
202

    // Record the exclusion data.

203
    exclusions = OpenCLArray::create<cl_uint>(context, tilesWithExclusions.size()*OpenCLContext::TileSize, "exclusions");
204
205
206
207
208
209
210
211
212
213
214
    vector<cl_uint> exclusionVec(exclusions->getSize());
    for (int i = 0; i < exclusions->getSize(); ++i)
        exclusionVec[i] = 0xFFFFFFFF;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x > y) {
215
216
                int index = findExclusionIndex(x, y, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset1] &= 0xFFFFFFFF-(1<<offset2);
217
218
            }
            else {
219
220
                int index = findExclusionIndex(y, x, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset2] &= 0xFFFFFFFF-(1<<offset1);
221
222
223
224
225
226
227
228
229
230
231
232
233
            }
        }
    }

    // Mark all interactions that involve a padding atom as being excluded.

    for (int atom1 = context.getNumAtoms(); atom1 < context.getPaddedNumAtoms(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int atom2 = 0; atom2 < context.getPaddedNumAtoms(); ++atom2) {
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x >= y) {
234
235
                int index = findExclusionIndex(x, y, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset1] &= 0xFFFFFFFF-(1<<offset2);
236
237
            }
            if (y >= x) {
238
239
                int index = findExclusionIndex(y, x, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset2] &= 0xFFFFFFFF-(1<<offset1);
240
241
242
243
244
            }
        }
    }
    atomExclusions.clear(); // We won't use this again, so free the memory it used
    exclusions->upload(exclusionVec);
245
246
247
248

    // Create data structures for the neighbor list.

    if (useCutoff) {
249
250
251
252
253
254
255
        // Select a size for the arrays that hold the neighbor list.  This estimate is intentionally very
        // high, because if it ever is too small, we have to fall back to the N^2 algorithm.

        mm_float4 boxSize = context.getPeriodicBoxSize();
        int maxInteractingTiles = (int) (numTiles*(cutoff/boxSize.x+cutoff/boxSize.y+cutoff/boxSize.z));
        if (maxInteractingTiles > numTiles)
            maxInteractingTiles = numTiles;
256
257
        if (maxInteractingTiles < 1)
            maxInteractingTiles = 1;
258
259
260
        interactingTiles = OpenCLArray::create<mm_ushort2>(context, maxInteractingTiles, "interactingTiles");
        interactionFlags = OpenCLArray::create<cl_uint>(context, context.getSIMDWidth() == 32 ? maxInteractingTiles : (deviceIsCpu ? 2*maxInteractingTiles : 1), "interactionFlags");
        interactionCount = OpenCLArray::create<cl_uint>(context, 1, "interactionCount");
261
262
263
        int elementSize = (context.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4));
        blockCenter = new OpenCLArray(context, numAtomBlocks, elementSize, "blockCenter");
        blockBoundingBox = new OpenCLArray(context, numAtomBlocks, elementSize, "blockBoundingBox");
264
265
        vector<cl_uint> count(1, 0);
        interactionCount->upload(count);
266
    }
267
268
269

    // Create kernels.

270
    forceKernel = createInteractionKernel(kernelSource, parameters, arguments, true, true);
271
272
    if (useCutoff) {
        map<string, string> defines;
273
        defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
274
275
276
277
        if (forceBufferPerAtomBlock)
            defines["USE_OUTPUT_BUFFER_PER_BLOCK"] = "1";
        if (usePeriodic)
            defines["USE_PERIODIC"] = "1";
278
279
        string file = (deviceIsCpu ? OpenCLKernelSources::findInteractingBlocks_cpu : OpenCLKernelSources::findInteractingBlocks);
        cl::Program interactingBlocksProgram = context.createProgram(file, defines);
280
281
        findBlockBoundsKernel = cl::Kernel(interactingBlocksProgram, "findBlockBounds");
        findBlockBoundsKernel.setArg<cl_int>(0, context.getNumAtoms());
282
283
284
        findBlockBoundsKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
        findBlockBoundsKernel.setArg<cl::Buffer>(4, blockCenter->getDeviceBuffer());
        findBlockBoundsKernel.setArg<cl::Buffer>(5, blockBoundingBox->getDeviceBuffer());
285
        findBlockBoundsKernel.setArg<cl::Buffer>(6, interactionCount->getDeviceBuffer());
286
        findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions");
287
288
289
290
        if (context.getUseDoublePrecision())
            findInteractingBlocksKernel.setArg<cl_double>(0, cutoff*cutoff);
        else
            findInteractingBlocksKernel.setArg<cl_float>(0, (cl_float) (cutoff*cutoff));
291
292
293
294
        findInteractingBlocksKernel.setArg<cl::Buffer>(3, blockCenter->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(4, blockBoundingBox->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(5, interactionCount->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
295
296
297
        findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(8, context.getPosq().getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl_uint>(9, interactingTiles->getSize());
298
299
        findInteractingBlocksKernel.setArg<cl_uint>(10, startTileIndex);
        findInteractingBlocksKernel.setArg<cl_uint>(11, startTileIndex+numTiles);
300
        if (context.getSIMDWidth() == 32 && !deviceIsCpu) {
301
            findInteractionsWithinBlocksKernel = cl::Kernel(interactingBlocksProgram, "findInteractionsWithinBlocks");
302
303
304
305
            if (context.getUseDoublePrecision())
                findInteractionsWithinBlocksKernel.setArg<cl_double>(0, cutoff*cutoff);
            else
                findInteractionsWithinBlocksKernel.setArg<cl_float>(0, (cl_float) (cutoff*cutoff));
306
307
308
309
310
311
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(4, interactingTiles->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(5, blockCenter->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(6, blockBoundingBox->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(8, interactionCount->getDeviceBuffer());
312
            findInteractionsWithinBlocksKernel.setArg(9, 128*sizeof(cl_uint), NULL);
313
            findInteractionsWithinBlocksKernel.setArg<cl_uint>(10, interactingTiles->getSize());
314
        }
315
    }
316
317
}

318
319
320
321
322
323
324
325
326
/**
 * Locate the exclusion flags of tile (x, y) in the row-indexed tile list.
 *
 * @param x                      the row of the tile
 * @param y                      the column of the tile
 * @param exclusionIndices       the column index of every tile with exclusions
 * @param exclusionRowIndices    for each row, the position in exclusionIndices where it starts
 * @return the position of the tile in exclusionIndices multiplied by OpenCLContext::TileSize
 * @throws OpenMMException if row x does not contain a tile with column y
 */
int OpenCLNonbondedUtilities::findExclusionIndex(int x, int y, const vector<cl_uint>& exclusionIndices, const vector<cl_uint>& exclusionRowIndices) {
    const int rowEnd = exclusionRowIndices[x+1];
    int slot = exclusionRowIndices[x];
    while (slot < rowEnd) {
        if (exclusionIndices[slot] == y)
            return slot*OpenCLContext::TileSize;
        ++slot;
    }
    throw OpenMMException("Internal error: exclusion in unexpected tile");
}

327
328
329
330
331
332
333
334
335
336
337
338
339
340
/**
 * Set a kernel argument to the periodic box size, using the double precision
 * value when the context uses double precision and the float value otherwise.
 */
static void setPeriodicBoxSizeArg(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (!cl.getUseDoublePrecision()) {
        kernel.setArg<mm_float4>(index, cl.getPeriodicBoxSize());
        return;
    }
    kernel.setArg<mm_double4>(index, cl.getPeriodicBoxSizeDouble());
}

/**
 * Set a kernel argument to the inverse periodic box size, using the double precision
 * value when the context uses double precision and the float value otherwise.
 */
static void setInvPeriodicBoxSizeArg(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (!cl.getUseDoublePrecision()) {
        kernel.setArg<mm_float4>(index, cl.getInvPeriodicBoxSize());
        return;
    }
    kernel.setArg<mm_double4>(index, cl.getInvPeriodicBoxSizeDouble());
}

341
342
343
void OpenCLNonbondedUtilities::prepareInteractions() {
    if (!useCutoff)
        return;
344
345
    if (usePeriodic) {
        mm_float4 box = context.getPeriodicBoxSize();
346
        double minAllowedSize = 1.999999*cutoff;
347
348
349
        if (box.x < minAllowedSize || box.y < minAllowedSize || box.z < minAllowedSize)
            throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
    }
350
351
352

    // Compute the neighbor list.

353
354
    setPeriodicBoxSizeArg(context, findBlockBoundsKernel, 1);
    setInvPeriodicBoxSizeArg(context, findBlockBoundsKernel, 2);
355
    context.executeKernel(findBlockBoundsKernel, context.getNumAtoms());
356
357
    setPeriodicBoxSizeArg(context, findInteractingBlocksKernel, 1);
    setInvPeriodicBoxSizeArg(context, findInteractingBlocksKernel, 2);
358
359
    context.executeKernel(findInteractingBlocksKernel, context.getNumAtoms(), deviceIsCpu ? 1 : -1);
    if (context.getSIMDWidth() == 32 && !deviceIsCpu) {
360
361
        setPeriodicBoxSizeArg(context, findInteractionsWithinBlocksKernel, 1);
        setInvPeriodicBoxSizeArg(context, findInteractionsWithinBlocksKernel, 2);
362
        context.executeKernel(findInteractionsWithinBlocksKernel, context.getNumAtoms(), 128);
363
    }
364
365
366
}

void OpenCLNonbondedUtilities::computeInteractions() {
367
    if (cutoff != -1.0) {
368
        if (useCutoff) {
369
370
            setPeriodicBoxSizeArg(context, forceKernel, 10);
            setInvPeriodicBoxSizeArg(context, forceKernel, 11);
371
        }
372
        context.executeKernel(forceKernel, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
373
    }
374
375
}

376
377
378
void OpenCLNonbondedUtilities::updateNeighborListSize() {
    if (!useCutoff)
        return;
379
380
381
    unsigned int* pinnedInteractionCount = (unsigned int*) context.getPinnedBuffer();
    interactionCount->download(pinnedInteractionCount);
    if (pinnedInteractionCount[0] <= (unsigned int) interactingTiles->getSize())
382
383
384
385
386
        return;

    // The most recent timestep had too many interactions to fit in the arrays.  Make the arrays bigger to prevent
    // this from happening in the future.

387
    int newSize = (int) (1.2*pinnedInteractionCount[0]);
388
389
390
391
    int numTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    if (newSize > numTiles)
        newSize = numTiles;
    delete interactingTiles;
392
    interactingTiles = OpenCLArray::create<mm_ushort2>(context, newSize, "interactingTiles");
393
394
    forceKernel.setArg<cl::Buffer>(8, interactingTiles->getDeviceBuffer());
    forceKernel.setArg<cl_uint>(12, newSize);
395
    findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
396
    findInteractingBlocksKernel.setArg<cl_uint>(9, newSize);
Peter Eastman's avatar
Peter Eastman committed
397
    if (context.getSIMDWidth() == 32 || deviceIsCpu) {
398
        delete interactionFlags;
399
        interactionFlags = OpenCLArray::create<cl_uint>(context, deviceIsCpu ? 2*newSize : newSize, "interactionFlags");
400
        forceKernel.setArg<cl::Buffer>(13, interactionFlags->getDeviceBuffer());
401
        findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
402
403
404
405
406
		if (!deviceIsCpu) {
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(4, interactingTiles->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl_uint>(10, newSize);
		}
407
408
409
    }
}

410
411
412
413
414
void OpenCLNonbondedUtilities::setTileRange(int startTileIndex, int numTiles) {
    this->startTileIndex = startTileIndex;
    this->numTiles = numTiles;
    if (cutoff == -1.0)
        return; // There are no nonbonded interactions in the System.
415
416
    forceKernel.setArg<cl_uint>(6, startTileIndex);
    forceKernel.setArg<cl_uint>(7, startTileIndex+numTiles);
417
418
419
420
    if (useCutoff) {
        findInteractingBlocksKernel.setArg<cl_uint>(10, startTileIndex);
        findInteractingBlocksKernel.setArg<cl_uint>(11, startTileIndex+numTiles);
    }
421
    else
422
        forceKernel.setArg<cl_uint>(8, numTiles);
423
424
}

425
/**
 * Generate, compile and configure a "computeNonbonded" kernel for an interaction.
 *
 * The kernel template is chosen by device type (CPU, SIMD width 32, or default).  Placeholders
 * in it are replaced with code generated from the interaction source and the per-atom
 * parameters, and all kernel arguments other than the periodic box size are set.
 *
 * @param source           the source code computing the interaction
 * @param params           the per-atom parameters the interaction depends on
 * @param arguments        additional arguments passed to the kernel
 * @param useExclusions    whether to define USE_EXCLUSIONS
 * @param isSymmetric      whether to define USE_SYMMETRIC
 * @return the configured kernel
 */
cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& source, const vector<ParameterInfo>& params, const vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric) const {
    const string suffixes[] = {"x", "y", "z", "w"};
    const int numParams = (int) params.size();
    const int numArguments = (int) arguments.size();

    // Generate every per-parameter code fragment in a single pass over the parameters.

    stringstream localData;   // fields of the per-atom local data struct
    stringstream args;        // extra kernel arguments
    stringstream loadLocal1;  // copy atom 1's parameters into local data
    stringstream loadLocal2;  // copy parameters from global memory into local data
    stringstream load1;       // load atom 1's parameters from global memory
    stringstream load2j;      // load atom 2's parameters from local data
    int localDataSize = 0;
    for (int p = 0; p < numParams; p++) {
        const ParameterInfo& param = params[p];
        const string name = param.getName();
        const string type = param.getType();
        const int numComponents = param.getNumComponents();
        localDataSize += param.getSize();
        args<<", __global const "<<type<<"* restrict global_"<<name;
        load1<<type<<" "<<name<<"1 = global_"<<name<<"[atom1];\n";
        if (numComponents == 1) {
            localData<<type<<" "<<name<<";\n";
            loadLocal1<<"localData[localAtomIndex]."<<name<<" = "<<name<<"1;\n";
            loadLocal2<<"localData[localAtomIndex]."<<name<<" = global_"<<name<<"[j];\n";
            load2j<<type<<" "<<name<<"2 = localData[atom2]."<<name<<";\n";
            continue;
        }

        // Multi-component parameters are stored in local data as one field per component.

        loadLocal2<<type<<" temp_"<<name<<" = global_"<<name<<"[j];\n";
        load2j<<type<<" "<<name<<"2 = ("<<type<<") (";
        for (int c = 0; c < numComponents; ++c) {
            localData<<param.getComponentType()<<" "<<name<<"_"<<suffixes[c]<<";\n";
            loadLocal1<<"localData[localAtomIndex]."<<name<<"_"<<suffixes[c]<<" = "<<name<<"1."<<suffixes[c]<<";\n";
            loadLocal2<<"localData[localAtomIndex]."<<name<<"_"<<suffixes[c]<<" = temp_"<<name<<"."<<suffixes[c]<<";\n";
            if (c > 0)
                load2j<<", ";
            load2j<<"localData[atom2]."<<name<<"_"<<suffixes[c];
        }
        load2j<<");\n";
    }
    for (int a = 0; a < numArguments; a++) {
        const ParameterInfo& argument = arguments[a];
        if (argument.getMemory().getInfo<CL_MEM_TYPE>() == CL_MEM_OBJECT_IMAGE2D) {
            args<<", __read_only image2d_t "<<argument.getName();
            continue;
        }
        const bool isReadOnly = ((argument.getMemory().getInfo<CL_MEM_FLAGS>() & CL_MEM_READ_ONLY) != 0);
        args<<(isReadOnly ? ", __constant " : ", __global const ");
        args<<argument.getType()<<"* restrict "<<argument.getName();
    }

    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = source;
    replacements["ATOM_PARAMETER_DATA"] = localData.str();
    replacements["PARAMETER_ARGUMENTS"] = args.str();
    replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
    replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
    replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
    replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str();

    // Collect the preprocessor definitions.

    map<string, string> defines;
    if (forceBufferPerAtomBlock)
        defines["USE_OUTPUT_BUFFER_PER_BLOCK"] = "1";
    if (useCutoff)
        defines["USE_CUTOFF"] = "1";
    if (usePeriodic)
        defines["USE_PERIODIC"] = "1";
    if (useExclusions)
        defines["USE_EXCLUSIONS"] = "1";
    if (isSymmetric)
        defines["USE_SYMMETRIC"] = "1";
    defines["FORCE_WORK_GROUP_SIZE"] = context.intToString(forceThreadBlockSize);
    defines["CUTOFF_SQUARED"] = context.doubleToString(cutoff*cutoff);
    defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
    defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
    if ((localDataSize/4)%2 == 0)
        defines["PARAMETER_SIZE_IS_EVEN"] = "1";

    // Pick the kernel template for this device and build the program.

    string file = (deviceIsCpu ? OpenCLKernelSources::nonbonded_cpu :
            (context.getSIMDWidth() == 32 ? OpenCLKernelSources::nonbonded_nvidia : OpenCLKernelSources::nonbonded_default));
    cl::Program program = context.createProgram(context.replaceStrings(file, replacements), defines);
    cl::Kernel kernel(program, "computeNonbonded");

    // Set arguments to the Kernel.

    int index = 0;
    if (context.getSupports64BitGlobalAtomics())
        kernel.setArg<cl::Memory>(index++, context.getLongForceBuffer().getDeviceBuffer());
    else
        kernel.setArg<cl::Buffer>(index++, context.getForceBuffers().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusions->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionIndices->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionRowIndices->getDeviceBuffer());
    kernel.setArg<cl_uint>(index++, startTileIndex);
    kernel.setArg<cl_uint>(index++, startTileIndex+numTiles);
    if (!useCutoff)
        kernel.setArg<cl_uint>(index++, numTiles);
    else {
        kernel.setArg<cl::Buffer>(index++, interactingTiles->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, interactionCount->getDeviceBuffer());
        index += 2; // The periodic box size arguments are set when the kernel is executed.
        kernel.setArg<cl_uint>(index++, interactingTiles->getSize());
        kernel.setArg<cl::Buffer>(index++, interactionFlags->getDeviceBuffer());
    }
    for (int p = 0; p < numParams; p++)
        kernel.setArg<cl::Memory>(index++, params[p].getMemory());
    for (int a = 0; a < numArguments; a++)
        kernel.setArg<cl::Memory>(index++, arguments[a].getMemory());
    return kernel;
}