/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2012 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "openmm/OpenMMException.h"
#include "OpenCLNonbondedUtilities.h"
#include "OpenCLArray.h"
#include "OpenCLKernelSources.h"
#include "OpenCLExpressionUtilities.h"
#include <algorithm>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

using namespace OpenMM;
using namespace std;

/**
 * Create the nonbonded utilities for a context and choose the launch configuration
 * of the force kernel: the number of thread blocks, the thread block size, and the
 * number of force buffers.  The choice depends on whether the device is a CPU, on
 * the SIMD width reported by the context, and on whether 64 bit global atomics are
 * supported.
 *
 * @param context    the context these utilities operate on
 */
OpenCLNonbondedUtilities::OpenCLNonbondedUtilities(OpenCLContext& context) : context(context), cutoff(-1.0), useCutoff(false), anyExclusions(false),
        numForceBuffers(0), exclusionIndices(NULL), exclusionRowIndices(NULL), exclusions(NULL), interactingTiles(NULL), interactionFlags(NULL),
        interactionCount(NULL), blockCenter(NULL), blockBoundingBox(NULL), nonbondedForceGroup(0) {
    // These members are otherwise first written by addInteraction(), initialize() or setTileRange().
    // Give them defined values so that nothing can observe an indeterminate value before then.
    // (Assigned in the body rather than the initializer list so member declaration order does not matter.)

    usePeriodic = false;
    startTileIndex = 0;
    numTiles = 0;

    // Decide how many thread blocks and force buffers to use.

    deviceIsCpu = (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
    forceBufferPerAtomBlock = false;
    if (deviceIsCpu) {
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = 1;
        numForceBuffers = numForceThreadBlocks;
    }
    else if (context.getSIMDWidth() == 32) {
        if (context.getSupports64BitGlobalAtomics()) {
            numForceThreadBlocks = 2*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
            forceThreadBlockSize = 256;
            // Even though using longForceBuffer, still need a single forceBuffer for the reduceForces kernel to convert the long results into float4 which will be used by later kernels.
            numForceBuffers = 1;
        }
        else {
            numForceThreadBlocks = 4*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
            forceThreadBlockSize = 128;
            numForceBuffers = numForceThreadBlocks;
        }
    }
    else {
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = OpenCLContext::ThreadBlockSize;
        if (context.getSupports64BitGlobalAtomics()) {
            // Even though using longForceBuffer, still need a single forceBuffer for the reduceForces kernel to convert the long results into float4 which will be used by later kernels.
            numForceBuffers = 1;
        }
        else {
            numForceBuffers = numForceThreadBlocks;
            if (numForceBuffers >= context.getNumAtomBlocks()) {
                // For small systems, it is more efficient to have one force buffer per block of 32 atoms instead of one per warp.

                forceBufferPerAtomBlock = true;
                numForceBuffers = context.getNumAtomBlocks();
            }
        }
    }
}

/**
 * Release every device array owned by this object.
 */
OpenCLNonbondedUtilities::~OpenCLNonbondedUtilities() {
    // Deleting a null pointer is a no-op, so arrays that were never created need no special handling.

    delete exclusionIndices;
    delete exclusionRowIndices;
    delete exclusions;
    delete interactingTiles;
    delete interactionFlags;
    delete interactionCount;
    delete blockCenter;
    delete blockBoundingBox;
}

void OpenCLNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup) {
103
104
105
106
107
108
109
    if (cutoff != -1.0) {
        if (usesCutoff != useCutoff)
            throw OpenMMException("All Forces must agree on whether to use a cutoff");
        if (usesPeriodic != usePeriodic)
            throw OpenMMException("All Forces must agree on whether to use periodic boundary conditions");
        if (cutoffDistance != cutoff)
            throw OpenMMException("All Forces must use the same cutoff distance");
110
111
        if (forceGroup != nonbondedForceGroup)
            throw OpenMMException("All nonbonded forces must be in the same force group");
112
    }
113
114
115
116
117
118
    if (usesExclusions)
        requestExclusions(exclusionList);
    useCutoff = usesCutoff;
    usePeriodic = usesPeriodic;
    cutoff = cutoffDistance;
    kernelSource += kernel+"\n";
119
    nonbondedForceGroup = forceGroup;
120
121
122
123
124
125
126
127
128
129
130
131
}

/**
 * Append a per-atom parameter to the list passed to createInteractionKernel() by initialize().
 *
 * @param parameter    the description of the parameter
 */
void OpenCLNonbondedUtilities::addParameter(const ParameterInfo& parameter) {
    parameters.insert(parameters.end(), parameter);
}

/**
 * Append an extra kernel argument to the list passed to createInteractionKernel() by initialize().
 *
 * @param parameter    the description of the argument
 */
void OpenCLNonbondedUtilities::addArgument(const ParameterInfo& parameter) {
    arguments.insert(arguments.end(), parameter);
}

void OpenCLNonbondedUtilities::requestExclusions(const vector<vector<int> >& exclusionList) {
    if (anyExclusions) {
132
        bool sameExclusions = (exclusionList.size() == atomExclusions.size());
133
        for (int i = 0; i < (int) exclusionList.size() && sameExclusions; i++) {
134
135
            if (exclusionList[i].size() != atomExclusions[i].size())
                sameExclusions = false;
136
            for (int j = 0; j < (int) exclusionList[i].size(); j++)
137
138
139
140
141
142
                if (exclusionList[i][j] != atomExclusions[i][j])
                    sameExclusions = false;
        }
        if (!sameExclusions)
            throw OpenMMException("All Forces must have identical exceptions");
    }
143
    else {
144
        atomExclusions = exclusionList;
145
146
        anyExclusions = true;
    }
147
148
149
150
151
152
}

void OpenCLNonbondedUtilities::initialize(const System& system) {
    if (cutoff == -1.0)
        return; // There are no nonbonded interactions in the System.
    
153
154
155
156
    if (atomExclusions.size() == 0) {
        // No exclusions were specifically requested, so just mark every atom as not interacting with itself.
        
        atomExclusions.resize(context.getNumAtoms());
157
        for (int i = 0; i < (int) atomExclusions.size(); i++)
158
159
160
            atomExclusions[i].push_back(i);
    }

161
162
163
    // Create the list of tiles.

    int numAtomBlocks = context.getNumAtomBlocks();
164
165
166
167
168
    int totalTiles = numAtomBlocks*(numAtomBlocks+1)/2;
    int numContexts = context.getPlatformData().contexts.size();
    startTileIndex = context.getContextIndex()*totalTiles/numContexts;
    int endTileIndex = (context.getContextIndex()+1)*totalTiles/numContexts;
    numTiles = endTileIndex-startTileIndex;
169
170
171

    // Build a list of indices for the tiles with exclusions.

172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
    set<pair<int, int> > tilesWithExclusions;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            tilesWithExclusions.insert(make_pair(max(x, y), min(x, y)));
        }
    }
    if (context.getPaddedNumAtoms() > context.getNumAtoms()) {
        for (int i = 0; i < numAtomBlocks; ++i)
            tilesWithExclusions.insert(make_pair(numAtomBlocks-1, i));
    }
    vector<cl_uint> exclusionRowIndicesVec(numAtomBlocks+1, 0);
    vector<cl_uint> exclusionIndicesVec;
    int currentRow = 0;
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter) {
        while (iter->first != currentRow)
            exclusionRowIndicesVec[++currentRow] = exclusionIndicesVec.size();
        exclusionIndicesVec.push_back(iter->second);
    }
    exclusionRowIndicesVec[++currentRow] = exclusionIndicesVec.size();
194
195
    exclusionIndices = OpenCLArray::create<cl_uint>(context, exclusionIndicesVec.size(), "exclusionIndices");
    exclusionRowIndices = OpenCLArray::create<cl_uint>(context, exclusionRowIndicesVec.size(), "exclusionRowIndices");
196
197
    exclusionIndices->upload(exclusionIndicesVec);
    exclusionRowIndices->upload(exclusionRowIndicesVec);
198
199
200

    // Record the exclusion data.

201
    exclusions = OpenCLArray::create<cl_uint>(context, tilesWithExclusions.size()*OpenCLContext::TileSize, "exclusions");
202
203
204
205
206
207
208
209
210
211
212
    vector<cl_uint> exclusionVec(exclusions->getSize());
    for (int i = 0; i < exclusions->getSize(); ++i)
        exclusionVec[i] = 0xFFFFFFFF;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x > y) {
213
214
                int index = findExclusionIndex(x, y, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset1] &= 0xFFFFFFFF-(1<<offset2);
215
216
            }
            else {
217
218
                int index = findExclusionIndex(y, x, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset2] &= 0xFFFFFFFF-(1<<offset1);
219
220
221
222
223
224
225
226
227
228
229
230
231
            }
        }
    }

    // Mark all interactions that involve a padding atom as being excluded.

    for (int atom1 = context.getNumAtoms(); atom1 < context.getPaddedNumAtoms(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int atom2 = 0; atom2 < context.getPaddedNumAtoms(); ++atom2) {
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x >= y) {
232
233
                int index = findExclusionIndex(x, y, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset1] &= 0xFFFFFFFF-(1<<offset2);
234
235
            }
            if (y >= x) {
236
237
                int index = findExclusionIndex(y, x, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset2] &= 0xFFFFFFFF-(1<<offset1);
238
239
240
241
242
            }
        }
    }
    atomExclusions.clear(); // We won't use this again, so free the memory it used
    exclusions->upload(exclusionVec);
243
244
245
246

    // Create data structures for the neighbor list.

    if (useCutoff) {
247
248
249
250
251
252
253
        // Select a size for the arrays that hold the neighbor list.  This estimate is intentionally very
        // high, because if it ever is too small, we have to fall back to the N^2 algorithm.

        mm_float4 boxSize = context.getPeriodicBoxSize();
        int maxInteractingTiles = (int) (numTiles*(cutoff/boxSize.x+cutoff/boxSize.y+cutoff/boxSize.z));
        if (maxInteractingTiles > numTiles)
            maxInteractingTiles = numTiles;
254
255
        if (maxInteractingTiles < 1)
            maxInteractingTiles = 1;
256
257
258
        interactingTiles = OpenCLArray::create<mm_ushort2>(context, maxInteractingTiles, "interactingTiles");
        interactionFlags = OpenCLArray::create<cl_uint>(context, context.getSIMDWidth() == 32 ? maxInteractingTiles : (deviceIsCpu ? 2*maxInteractingTiles : 1), "interactionFlags");
        interactionCount = OpenCLArray::create<cl_uint>(context, 1, "interactionCount");
259
260
261
        int elementSize = (context.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4));
        blockCenter = new OpenCLArray(context, numAtomBlocks, elementSize, "blockCenter");
        blockBoundingBox = new OpenCLArray(context, numAtomBlocks, elementSize, "blockBoundingBox");
262
263
        vector<cl_uint> count(1, 0);
        interactionCount->upload(count);
264
    }
265
266
267

    // Create kernels.

268
    forceKernel = createInteractionKernel(kernelSource, parameters, arguments, true, true);
269
270
    if (useCutoff) {
        map<string, string> defines;
271
        defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
272
273
274
275
        if (forceBufferPerAtomBlock)
            defines["USE_OUTPUT_BUFFER_PER_BLOCK"] = "1";
        if (usePeriodic)
            defines["USE_PERIODIC"] = "1";
276
277
        string file = (deviceIsCpu ? OpenCLKernelSources::findInteractingBlocks_cpu : OpenCLKernelSources::findInteractingBlocks);
        cl::Program interactingBlocksProgram = context.createProgram(file, defines);
278
279
        findBlockBoundsKernel = cl::Kernel(interactingBlocksProgram, "findBlockBounds");
        findBlockBoundsKernel.setArg<cl_int>(0, context.getNumAtoms());
280
281
282
        findBlockBoundsKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
        findBlockBoundsKernel.setArg<cl::Buffer>(4, blockCenter->getDeviceBuffer());
        findBlockBoundsKernel.setArg<cl::Buffer>(5, blockBoundingBox->getDeviceBuffer());
283
        findBlockBoundsKernel.setArg<cl::Buffer>(6, interactionCount->getDeviceBuffer());
284
        findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions");
285
286
287
288
        if (context.getUseDoublePrecision())
            findInteractingBlocksKernel.setArg<cl_double>(0, cutoff*cutoff);
        else
            findInteractingBlocksKernel.setArg<cl_float>(0, (cl_float) (cutoff*cutoff));
289
290
291
292
        findInteractingBlocksKernel.setArg<cl::Buffer>(3, blockCenter->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(4, blockBoundingBox->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(5, interactionCount->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
293
294
295
        findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(8, context.getPosq().getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl_uint>(9, interactingTiles->getSize());
296
297
        findInteractingBlocksKernel.setArg<cl_uint>(10, startTileIndex);
        findInteractingBlocksKernel.setArg<cl_uint>(11, startTileIndex+numTiles);
298
        if (context.getSIMDWidth() == 32 && !deviceIsCpu) {
299
            findInteractionsWithinBlocksKernel = cl::Kernel(interactingBlocksProgram, "findInteractionsWithinBlocks");
300
301
302
303
            if (context.getUseDoublePrecision())
                findInteractionsWithinBlocksKernel.setArg<cl_double>(0, cutoff*cutoff);
            else
                findInteractionsWithinBlocksKernel.setArg<cl_float>(0, (cl_float) (cutoff*cutoff));
304
305
306
307
308
309
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(4, interactingTiles->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(5, blockCenter->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(6, blockBoundingBox->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(8, interactionCount->getDeviceBuffer());
310
            findInteractionsWithinBlocksKernel.setArg(9, 128*sizeof(cl_uint), NULL);
311
            findInteractionsWithinBlocksKernel.setArg<cl_uint>(10, interactingTiles->getSize());
312
        }
313
    }
314
315
}

316
317
318
319
320
321
322
323
324
int OpenCLNonbondedUtilities::findExclusionIndex(int x, int y, const vector<cl_uint>& exclusionIndices, const vector<cl_uint>& exclusionRowIndices) {
    int start = exclusionRowIndices[x];
    int end = exclusionRowIndices[x+1];
    for (int i = start; i < end; i++)
        if (exclusionIndices[i] == y)
            return i*OpenCLContext::TileSize;
    throw OpenMMException("Internal error: exclusion in unexpected tile");
}

325
326
327
328
329
330
331
332
333
334
335
336
337
338
/**
 * Set a kernel argument to the periodic box size, using the double or single precision
 * representation according to the context's precision setting.
 *
 * @param cl      the context to read the box size from
 * @param kernel  the kernel whose argument is set
 * @param index   the index of the argument
 */
static void setPeriodicBoxSizeArg(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (!cl.getUseDoublePrecision()) {
        kernel.setArg<mm_float4>(index, cl.getPeriodicBoxSize());
        return;
    }
    kernel.setArg<mm_double4>(index, cl.getPeriodicBoxSizeDouble());
}

/**
 * Set a kernel argument to the inverse periodic box size, using the double or single
 * precision representation according to the context's precision setting.
 *
 * @param cl      the context to read the inverse box size from
 * @param kernel  the kernel whose argument is set
 * @param index   the index of the argument
 */
static void setInvPeriodicBoxSizeArg(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (!cl.getUseDoublePrecision()) {
        kernel.setArg<mm_float4>(index, cl.getInvPeriodicBoxSize());
        return;
    }
    kernel.setArg<mm_double4>(index, cl.getInvPeriodicBoxSizeDouble());
}

void OpenCLNonbondedUtilities::prepareInteractions() {
    if (!useCutoff)
        return;
342
343
    if (usePeriodic) {
        mm_float4 box = context.getPeriodicBoxSize();
344
        double minAllowedSize = 1.999999*cutoff;
345
346
347
        if (box.x < minAllowedSize || box.y < minAllowedSize || box.z < minAllowedSize)
            throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
    }
348
349
350

    // Compute the neighbor list.

351
352
    setPeriodicBoxSizeArg(context, findBlockBoundsKernel, 1);
    setInvPeriodicBoxSizeArg(context, findBlockBoundsKernel, 2);
353
    context.executeKernel(findBlockBoundsKernel, context.getNumAtoms());
354
355
    setPeriodicBoxSizeArg(context, findInteractingBlocksKernel, 1);
    setInvPeriodicBoxSizeArg(context, findInteractingBlocksKernel, 2);
356
357
    context.executeKernel(findInteractingBlocksKernel, context.getNumAtoms(), deviceIsCpu ? 1 : -1);
    if (context.getSIMDWidth() == 32 && !deviceIsCpu) {
358
359
        setPeriodicBoxSizeArg(context, findInteractionsWithinBlocksKernel, 1);
        setInvPeriodicBoxSizeArg(context, findInteractionsWithinBlocksKernel, 2);
360
        context.executeKernel(findInteractionsWithinBlocksKernel, context.getNumAtoms(), 128);
361
    }
362
363
364
}

void OpenCLNonbondedUtilities::computeInteractions() {
365
    if (cutoff != -1.0) {
366
        if (useCutoff) {
367
368
            setPeriodicBoxSizeArg(context, forceKernel, 10);
            setInvPeriodicBoxSizeArg(context, forceKernel, 11);
369
        }
370
        context.executeKernel(forceKernel, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
371
    }
372
373
}

374
375
376
void OpenCLNonbondedUtilities::updateNeighborListSize() {
    if (!useCutoff)
        return;
377
378
379
    unsigned int* pinnedInteractionCount = (unsigned int*) context.getPinnedBuffer();
    interactionCount->download(pinnedInteractionCount);
    if (pinnedInteractionCount[0] <= (unsigned int) interactingTiles->getSize())
380
381
382
383
384
        return;

    // The most recent timestep had too many interactions to fit in the arrays.  Make the arrays bigger to prevent
    // this from happening in the future.

385
    int newSize = (int) (1.2*pinnedInteractionCount[0]);
386
387
388
389
    int numTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    if (newSize > numTiles)
        newSize = numTiles;
    delete interactingTiles;
390
    interactingTiles = OpenCLArray::create<mm_ushort2>(context, newSize, "interactingTiles");
391
392
    forceKernel.setArg<cl::Buffer>(8, interactingTiles->getDeviceBuffer());
    forceKernel.setArg<cl_uint>(12, newSize);
393
    findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
394
    findInteractingBlocksKernel.setArg<cl_uint>(9, newSize);
Peter Eastman's avatar
Peter Eastman committed
395
    if (context.getSIMDWidth() == 32 || deviceIsCpu) {
396
        delete interactionFlags;
397
        interactionFlags = OpenCLArray::create<cl_uint>(context, deviceIsCpu ? 2*newSize : newSize, "interactionFlags");
398
        forceKernel.setArg<cl::Buffer>(13, interactionFlags->getDeviceBuffer());
399
        findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
400
401
402
403
404
		if (!deviceIsCpu) {
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(4, interactingTiles->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl_uint>(10, newSize);
		}
405
406
407
    }
}

408
409
410
411
412
void OpenCLNonbondedUtilities::setTileRange(int startTileIndex, int numTiles) {
    this->startTileIndex = startTileIndex;
    this->numTiles = numTiles;
    if (cutoff == -1.0)
        return; // There are no nonbonded interactions in the System.
413
414
    forceKernel.setArg<cl_uint>(6, startTileIndex);
    forceKernel.setArg<cl_uint>(7, startTileIndex+numTiles);
415
416
417
418
    if (useCutoff) {
        findInteractingBlocksKernel.setArg<cl_uint>(10, startTileIndex);
        findInteractingBlocksKernel.setArg<cl_uint>(11, startTileIndex+numTiles);
    }
419
    else
420
        forceKernel.setArg<cl_uint>(8, numTiles);
421
422
}

423
/**
 * Create a kernel that evaluates a nonbonded interaction.  The interaction source and
 * generated code for loading the per-atom parameters are substituted into a kernel template
 * selected by device type, the program is compiled, and the kernel's arguments are set.
 *
 * @param source         the source code of the interaction
 * @param params         the per-atom parameters the interaction depends on
 * @param arguments      additional arrays passed to the kernel as arguments
 * @param useExclusions  whether to define USE_EXCLUSIONS when compiling
 * @param isSymmetric    whether to define USE_SYMMETRIC when compiling
 * @return the newly created kernel
 */
cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& source, const vector<ParameterInfo>& params, const vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric) const {
    const string suffixes[] = {"x", "y", "z", "w"};
    const int numParams = (int) params.size();
    const int numArguments = (int) arguments.size();

    // Generate the code fragments that depend on the per-atom parameters.  Each stream feeds
    // one placeholder of the kernel template.

    stringstream structFields;      // ATOM_PARAMETER_DATA
    stringstream kernelArgs;        // PARAMETER_ARGUMENTS
    stringstream storeFromAtom1;    // LOAD_LOCAL_PARAMETERS_FROM_1
    stringstream storeFromGlobal;   // LOAD_LOCAL_PARAMETERS_FROM_GLOBAL
    stringstream loadAtom1;         // LOAD_ATOM1_PARAMETERS
    stringstream loadAtom2;         // LOAD_ATOM2_PARAMETERS
    int localDataSize = 0;
    for (int p = 0; p < numParams; p++) {
        const ParameterInfo& param = params[p];
        const string name = param.getName();
        const string type = param.getType();
        const int numComponents = param.getNumComponents();
        localDataSize += param.getSize();
        kernelArgs<<", __global const "<<type<<"* restrict global_"<<name;
        loadAtom1<<type<<" "<<name<<"1 = global_"<<name<<"[atom1];\n";
        if (numComponents == 1) {
            // A scalar parameter is stored as a single field.

            structFields<<type<<" "<<name<<";\n";
            storeFromAtom1<<"localData[localAtomIndex]."<<name<<" = "<<name<<"1;\n";
            storeFromGlobal<<"localData[localAtomIndex]."<<name<<" = global_"<<name<<"[j];\n";
            loadAtom2<<type<<" "<<name<<"2 = localData[atom2]."<<name<<";\n";
            continue;
        }

        // A vector parameter is stored as one field per component.

        const string componentType = param.getComponentType();
        storeFromGlobal<<type<<" temp_"<<name<<" = global_"<<name<<"[j];\n";
        loadAtom2<<type<<" "<<name<<"2 = ("<<type<<") (";
        for (int c = 0; c < numComponents; ++c) {
            const string field = name+"_"+suffixes[c];
            structFields<<componentType<<" "<<field<<";\n";
            storeFromAtom1<<"localData[localAtomIndex]."<<field<<" = "<<name<<"1."<<suffixes[c]<<";\n";
            storeFromGlobal<<"localData[localAtomIndex]."<<field<<" = temp_"<<name<<"."<<suffixes[c]<<";\n";
            if (c > 0)
                loadAtom2<<", ";
            loadAtom2<<"localData[atom2]."<<field;
        }
        loadAtom2<<");\n";
    }

    // The extra arguments follow the per-atom parameters in the kernel's argument list.

    for (int a = 0; a < numArguments; a++) {
        const ParameterInfo& argument = arguments[a];
        if (argument.getMemory().getInfo<CL_MEM_TYPE>() == CL_MEM_OBJECT_IMAGE2D) {
            kernelArgs<<", __read_only image2d_t "<<argument.getName();
            continue;
        }
        const bool isReadOnly = ((argument.getMemory().getInfo<CL_MEM_FLAGS>() & CL_MEM_READ_ONLY) != 0);
        kernelArgs<<(isReadOnly ? ", __constant " : ", __global const ");
        kernelArgs<<argument.getType()<<"* restrict "<<argument.getName();
    }

    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = source;
    replacements["ATOM_PARAMETER_DATA"] = structFields.str();
    replacements["PARAMETER_ARGUMENTS"] = kernelArgs.str();
    replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = storeFromAtom1.str();
    replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = storeFromGlobal.str();
    replacements["LOAD_ATOM1_PARAMETERS"] = loadAtom1.str();
    replacements["LOAD_ATOM2_PARAMETERS"] = loadAtom2.str();

    // Select the preprocessor definitions.

    map<string, string> defines;
    if (forceBufferPerAtomBlock)
        defines["USE_OUTPUT_BUFFER_PER_BLOCK"] = "1";
    if (useCutoff)
        defines["USE_CUTOFF"] = "1";
    if (usePeriodic)
        defines["USE_PERIODIC"] = "1";
    if (useExclusions)
        defines["USE_EXCLUSIONS"] = "1";
    if (isSymmetric)
        defines["USE_SYMMETRIC"] = "1";
    if ((localDataSize/4)%2 == 0)
        defines["PARAMETER_SIZE_IS_EVEN"] = "1";
    defines["FORCE_WORK_GROUP_SIZE"] = context.intToString(forceThreadBlockSize);
    defines["CUTOFF_SQUARED"] = context.doubleToString(cutoff*cutoff);
    defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
    defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());

    // Pick the kernel template for this device and compile it.

    string file;
    if (deviceIsCpu)
        file = OpenCLKernelSources::nonbonded_cpu;
    else if (context.getSIMDWidth() == 32)
        file = OpenCLKernelSources::nonbonded_nvidia;
    else
        file = OpenCLKernelSources::nonbonded_default;
    cl::Program program = context.createProgram(context.replaceStrings(file, replacements), defines);
    cl::Kernel kernel(program, "computeNonbonded");

    // Set arguments to the Kernel.

    int index = 0;
    if (context.getSupports64BitGlobalAtomics())
        kernel.setArg<cl::Memory>(index++, context.getLongForceBuffer().getDeviceBuffer());
    else
        kernel.setArg<cl::Buffer>(index++, context.getForceBuffers().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusions->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionIndices->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionRowIndices->getDeviceBuffer());
    kernel.setArg<cl_uint>(index++, startTileIndex);
    kernel.setArg<cl_uint>(index++, startTileIndex+numTiles);
    if (!useCutoff)
        kernel.setArg<cl_uint>(index++, numTiles);
    else {
        kernel.setArg<cl::Buffer>(index++, interactingTiles->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, interactionCount->getDeviceBuffer());
        index += 2; // The periodic box size arguments are set when the kernel is executed.
        kernel.setArg<cl_uint>(index++, interactingTiles->getSize());
        kernel.setArg<cl::Buffer>(index++, interactionFlags->getDeviceBuffer());
    }
    for (int p = 0; p < numParams; p++)
        kernel.setArg<cl::Memory>(index++, params[p].getMemory());
    for (int a = 0; a < numArguments; a++)
        kernel.setArg<cl::Memory>(index++, arguments[a].getMemory());
    return kernel;
}