"plugins/vscode:/vscode.git/clone" did not exist on "3e974a97823fad9e71ca9b4f45796852725f0a77"
OpenCLNonbondedUtilities.cpp 26.9 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
9
 * Portions copyright (c) 2009-2011 Stanford University and the Authors.      *
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

27
#include "openmm/OpenMMException.h"
28
29
#include "OpenCLNonbondedUtilities.h"
#include "OpenCLArray.h"
30
#include "OpenCLKernelSources.h"
31
#include "OpenCLExpressionUtilities.h"
32
#include <map>
33
34
#include <set>
#include <utility>
35
36
37
38

using namespace OpenMM;
using namespace std;

39
OpenCLNonbondedUtilities::OpenCLNonbondedUtilities(OpenCLContext& context) : context(context), cutoff(-1.0), useCutoff(false), anyExclusions(false),
40
        numForceBuffers(0), exclusionIndices(NULL), exclusionRowIndices(NULL), exclusions(NULL), interactingTiles(NULL), interactionFlags(NULL),
41
        interactionCount(NULL), blockCenter(NULL), blockBoundingBox(NULL), nonbondedForceGroup(0) {
42
    // Decide how many thread blocks and force buffers to use.
43

44
    deviceIsCpu = (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
45
    forceBufferPerAtomBlock = false;
46
47
48
49
50
51
    if (deviceIsCpu) {
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = 1;
        numForceBuffers = numForceThreadBlocks;
    }
    else if (context.getSIMDWidth() == 32) {
52
        if (context.getSupports64BitGlobalAtomics()) {
53
            numForceThreadBlocks = 2*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
54
            forceThreadBlockSize = 256;
55
56
            // Even though using longForceBuffer, still need a single forceBuffer for the reduceForces kernel to convert the long results into float4 which will be used by later kernels.
            numForceBuffers = 1;
57
58
59
60
61
62
        }
        else {
            numForceThreadBlocks = 4*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
            forceThreadBlockSize = 128;
            numForceBuffers = numForceThreadBlocks;
        }
63
    }
64
    else {
65
66
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = OpenCLContext::ThreadBlockSize;
67
68
69
70
71
72
73
74
        if (context.getSupports64BitGlobalAtomics()) {
            // Even though using longForceBuffer, still need a single forceBuffer for the reduceForces kernel to convert the long results into float4 which will be used by later kernels.
            numForceBuffers = 1;
        }
        else {
            numForceBuffers = numForceThreadBlocks;
            if (numForceBuffers >= context.getNumAtomBlocks()) {
                // For small systems, it is more efficient to have one force buffer per block of 32 atoms instead of one per warp.
75

76
77
78
                forceBufferPerAtomBlock = true;
                numForceBuffers = context.getNumAtomBlocks();
            }
79
        }
80
    }
81
82
83
}

OpenCLNonbondedUtilities::~OpenCLNonbondedUtilities() {
84
85
86
87
    if (exclusionIndices != NULL)
        delete exclusionIndices;
    if (exclusionRowIndices != NULL)
        delete exclusionRowIndices;
88
89
    if (exclusions != NULL)
        delete exclusions;
90
91
92
93
94
95
96
97
98
99
    if (interactingTiles != NULL)
        delete interactingTiles;
    if (interactionFlags != NULL)
        delete interactionFlags;
    if (interactionCount != NULL)
        delete interactionCount;
    if (blockCenter != NULL)
        delete blockCenter;
    if (blockBoundingBox != NULL)
        delete blockBoundingBox;
100
101
}

102
void OpenCLNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup) {
103
104
105
106
107
108
109
    if (cutoff != -1.0) {
        if (usesCutoff != useCutoff)
            throw OpenMMException("All Forces must agree on whether to use a cutoff");
        if (usesPeriodic != usePeriodic)
            throw OpenMMException("All Forces must agree on whether to use periodic boundary conditions");
        if (cutoffDistance != cutoff)
            throw OpenMMException("All Forces must use the same cutoff distance");
110
111
        if (forceGroup != nonbondedForceGroup)
            throw OpenMMException("All nonbonded forces must be in the same force group");
112
    }
113
114
115
116
117
118
    if (usesExclusions)
        requestExclusions(exclusionList);
    useCutoff = usesCutoff;
    usePeriodic = usesPeriodic;
    cutoff = cutoffDistance;
    kernelSource += kernel+"\n";
119
    nonbondedForceGroup = forceGroup;
120
121
122
123
124
125
126
127
128
129
130
131
}

void OpenCLNonbondedUtilities::addParameter(const ParameterInfo& parameter) {
    parameters.push_back(parameter);
}

void OpenCLNonbondedUtilities::addArgument(const ParameterInfo& parameter) {
    arguments.push_back(parameter);
}

void OpenCLNonbondedUtilities::requestExclusions(const vector<vector<int> >& exclusionList) {
    if (anyExclusions) {
132
        bool sameExclusions = (exclusionList.size() == atomExclusions.size());
133
        for (int i = 0; i < (int) exclusionList.size() && sameExclusions; i++) {
134
135
            if (exclusionList[i].size() != atomExclusions[i].size())
                sameExclusions = false;
136
            for (int j = 0; j < (int) exclusionList[i].size(); j++)
137
138
139
140
141
142
                if (exclusionList[i][j] != atomExclusions[i][j])
                    sameExclusions = false;
        }
        if (!sameExclusions)
            throw OpenMMException("All Forces must have identical exceptions");
    }
143
    else {
144
        atomExclusions = exclusionList;
145
146
        anyExclusions = true;
    }
147
148
149
150
151
152
}

void OpenCLNonbondedUtilities::initialize(const System& system) {
    if (cutoff == -1.0)
        return; // There are no nonbonded interactions in the System.
    
153
154
155
156
    if (atomExclusions.size() == 0) {
        // No exclusions were specifically requested, so just mark every atom as not interacting with itself.
        
        atomExclusions.resize(context.getNumAtoms());
157
        for (int i = 0; i < (int) atomExclusions.size(); i++)
158
159
160
            atomExclusions[i].push_back(i);
    }

161
162
163
    // Create the list of tiles.

    int numAtomBlocks = context.getNumAtomBlocks();
164
165
166
167
168
    int totalTiles = numAtomBlocks*(numAtomBlocks+1)/2;
    int numContexts = context.getPlatformData().contexts.size();
    startTileIndex = context.getContextIndex()*totalTiles/numContexts;
    int endTileIndex = (context.getContextIndex()+1)*totalTiles/numContexts;
    numTiles = endTileIndex-startTileIndex;
169
170
171

    // Build a list of indices for the tiles with exclusions.

172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
    set<pair<int, int> > tilesWithExclusions;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            tilesWithExclusions.insert(make_pair(max(x, y), min(x, y)));
        }
    }
    if (context.getPaddedNumAtoms() > context.getNumAtoms()) {
        for (int i = 0; i < numAtomBlocks; ++i)
            tilesWithExclusions.insert(make_pair(numAtomBlocks-1, i));
    }
    vector<cl_uint> exclusionRowIndicesVec(numAtomBlocks+1, 0);
    vector<cl_uint> exclusionIndicesVec;
    int currentRow = 0;
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter) {
        while (iter->first != currentRow)
            exclusionRowIndicesVec[++currentRow] = exclusionIndicesVec.size();
        exclusionIndicesVec.push_back(iter->second);
    }
    exclusionRowIndicesVec[++currentRow] = exclusionIndicesVec.size();
    exclusionIndices = new OpenCLArray<cl_uint>(context, exclusionIndicesVec.size(), "exclusionIndices");
    exclusionRowIndices = new OpenCLArray<cl_uint>(context, exclusionRowIndicesVec.size(), "exclusionRowIndices");
    exclusionIndices->upload(exclusionIndicesVec);
    exclusionRowIndices->upload(exclusionRowIndicesVec);
198
199
200

    // Record the exclusion data.

201
    exclusions = new OpenCLArray<cl_uint>(context, tilesWithExclusions.size()*OpenCLContext::TileSize, "exclusions");
202
203
204
205
206
207
208
209
210
211
212
    vector<cl_uint> exclusionVec(exclusions->getSize());
    for (int i = 0; i < exclusions->getSize(); ++i)
        exclusionVec[i] = 0xFFFFFFFF;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x > y) {
213
214
                int index = findExclusionIndex(x, y, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset1] &= 0xFFFFFFFF-(1<<offset2);
215
216
            }
            else {
217
218
                int index = findExclusionIndex(y, x, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset2] &= 0xFFFFFFFF-(1<<offset1);
219
220
221
222
223
224
225
226
227
228
229
230
231
            }
        }
    }

    // Mark all interactions that involve a padding atom as being excluded.

    for (int atom1 = context.getNumAtoms(); atom1 < context.getPaddedNumAtoms(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int atom2 = 0; atom2 < context.getPaddedNumAtoms(); ++atom2) {
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x >= y) {
232
233
                int index = findExclusionIndex(x, y, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset1] &= 0xFFFFFFFF-(1<<offset2);
234
235
            }
            if (y >= x) {
236
237
                int index = findExclusionIndex(y, x, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset2] &= 0xFFFFFFFF-(1<<offset1);
238
239
240
241
242
            }
        }
    }
    atomExclusions.clear(); // We won't use this again, so free the memory it used
    exclusions->upload(exclusionVec);
243
244
245
246

    // Create data structures for the neighbor list.

    if (useCutoff) {
247
248
249
250
251
252
253
        // Select a size for the arrays that hold the neighbor list.  This estimate is intentionally very
        // high, because if it ever is too small, we have to fall back to the N^2 algorithm.

        mm_float4 boxSize = context.getPeriodicBoxSize();
        int maxInteractingTiles = (int) (numTiles*(cutoff/boxSize.x+cutoff/boxSize.y+cutoff/boxSize.z));
        if (maxInteractingTiles > numTiles)
            maxInteractingTiles = numTiles;
254
255
        if (maxInteractingTiles < 1)
            maxInteractingTiles = 1;
256
        interactingTiles = new OpenCLArray<mm_ushort2>(context, maxInteractingTiles, "interactingTiles");
257
        interactionFlags = new OpenCLArray<cl_uint>(context, context.getSIMDWidth() == 32 ? maxInteractingTiles : (deviceIsCpu ? 2*maxInteractingTiles : 1), "interactionFlags");
258
        interactionCount = new OpenCLArray<cl_uint>(context, 1, "interactionCount", true);
259
260
        blockCenter = new OpenCLArray<mm_float4>(context, numAtomBlocks, "blockCenter");
        blockBoundingBox = new OpenCLArray<mm_float4>(context, numAtomBlocks, "blockBoundingBox");
261
262
        interactionCount->set(0, 0);
        interactionCount->upload();
263
    }
264
265
266

    // Create kernels.

267
    forceKernel = createInteractionKernel(kernelSource, parameters, arguments, true, true);
268
269
    if (useCutoff) {
        map<string, string> defines;
270
        defines["NUM_BLOCKS"] = OpenCLExpressionUtilities::intToString(context.getNumAtomBlocks());
271
272
273
274
        if (forceBufferPerAtomBlock)
            defines["USE_OUTPUT_BUFFER_PER_BLOCK"] = "1";
        if (usePeriodic)
            defines["USE_PERIODIC"] = "1";
275
276
        string file = (deviceIsCpu ? OpenCLKernelSources::findInteractingBlocks_cpu : OpenCLKernelSources::findInteractingBlocks);
        cl::Program interactingBlocksProgram = context.createProgram(file, defines);
277
278
        findBlockBoundsKernel = cl::Kernel(interactingBlocksProgram, "findBlockBounds");
        findBlockBoundsKernel.setArg<cl_int>(0, context.getNumAtoms());
279
280
281
        findBlockBoundsKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
        findBlockBoundsKernel.setArg<cl::Buffer>(4, blockCenter->getDeviceBuffer());
        findBlockBoundsKernel.setArg<cl::Buffer>(5, blockBoundingBox->getDeviceBuffer());
282
        findBlockBoundsKernel.setArg<cl::Buffer>(6, interactionCount->getDeviceBuffer());
283
        findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions");
284
285
286
287
288
        findInteractingBlocksKernel.setArg<cl_float>(0, (cl_float) (cutoff*cutoff));
        findInteractingBlocksKernel.setArg<cl::Buffer>(3, blockCenter->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(4, blockBoundingBox->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(5, interactionCount->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
289
290
291
        findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(8, context.getPosq().getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl_uint>(9, interactingTiles->getSize());
292
293
        findInteractingBlocksKernel.setArg<cl_uint>(10, startTileIndex);
        findInteractingBlocksKernel.setArg<cl_uint>(11, startTileIndex+numTiles);
294
        if (context.getSIMDWidth() == 32 && !deviceIsCpu) {
295
296
297
298
299
300
301
302
            findInteractionsWithinBlocksKernel = cl::Kernel(interactingBlocksProgram, "findInteractionsWithinBlocks");
            findInteractionsWithinBlocksKernel.setArg<cl_float>(0, (cl_float) (cutoff*cutoff));
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(4, interactingTiles->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(5, blockCenter->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(6, blockBoundingBox->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(8, interactionCount->getDeviceBuffer());
303
            findInteractionsWithinBlocksKernel.setArg(9, 128*sizeof(cl_uint), NULL);
304
            findInteractionsWithinBlocksKernel.setArg<cl_uint>(10, interactingTiles->getSize());
305
        }
306
    }
307
308
}

309
310
311
312
313
314
315
316
317
int OpenCLNonbondedUtilities::findExclusionIndex(int x, int y, const vector<cl_uint>& exclusionIndices, const vector<cl_uint>& exclusionRowIndices) {
    int start = exclusionRowIndices[x];
    int end = exclusionRowIndices[x+1];
    for (int i = start; i < end; i++)
        if (exclusionIndices[i] == y)
            return i*OpenCLContext::TileSize;
    throw OpenMMException("Internal error: exclusion in unexpected tile");
}

318
319
320
void OpenCLNonbondedUtilities::prepareInteractions() {
    if (!useCutoff)
        return;
321
322
323
324
325
326
    if (usePeriodic) {
        mm_float4 box = context.getPeriodicBoxSize();
        double minAllowedSize = 2*cutoff;
        if (box.x < minAllowedSize || box.y < minAllowedSize || box.z < minAllowedSize)
            throw OpenMMException("The periodic box size has decreased to less than twice the nonbonded cutoff.");
    }
327
328
329

    // Compute the neighbor list.

330
331
    findBlockBoundsKernel.setArg<mm_float4>(1, context.getPeriodicBoxSize());
    findBlockBoundsKernel.setArg<mm_float4>(2, context.getInvPeriodicBoxSize());
332
    context.executeKernel(findBlockBoundsKernel, context.getNumAtoms());
333
334
    findInteractingBlocksKernel.setArg<mm_float4>(1, context.getPeriodicBoxSize());
    findInteractingBlocksKernel.setArg<mm_float4>(2, context.getInvPeriodicBoxSize());
335
336
    context.executeKernel(findInteractingBlocksKernel, context.getNumAtoms(), deviceIsCpu ? 1 : -1);
    if (context.getSIMDWidth() == 32 && !deviceIsCpu) {
337
338
        findInteractionsWithinBlocksKernel.setArg<mm_float4>(1, context.getPeriodicBoxSize());
        findInteractionsWithinBlocksKernel.setArg<mm_float4>(2, context.getInvPeriodicBoxSize());
339
        context.executeKernel(findInteractionsWithinBlocksKernel, context.getNumAtoms(), 128);
340
    }
341
342
343
}

void OpenCLNonbondedUtilities::computeInteractions() {
344
    if (cutoff != -1.0) {
345
        if (useCutoff) {
346
347
            forceKernel.setArg<mm_float4>(10, context.getPeriodicBoxSize());
            forceKernel.setArg<mm_float4>(11, context.getInvPeriodicBoxSize());
348
        }
349
        context.executeKernel(forceKernel, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
350
    }
351
352
}

353
354
355
356
void OpenCLNonbondedUtilities::updateNeighborListSize() {
    if (!useCutoff)
        return;
    interactionCount->download();
357
    if (interactionCount->get(0) <= (unsigned int) interactingTiles->getSize())
358
359
360
361
362
363
364
365
366
367
368
        return;

    // The most recent timestep had too many interactions to fit in the arrays.  Make the arrays bigger to prevent
    // this from happening in the future.

    int newSize = (int) (1.2*interactionCount->get(0));
    int numTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    if (newSize > numTiles)
        newSize = numTiles;
    delete interactingTiles;
    interactingTiles = new OpenCLArray<mm_ushort2>(context, newSize, "interactingTiles");
369
370
    forceKernel.setArg<cl::Buffer>(8, interactingTiles->getDeviceBuffer());
    forceKernel.setArg<cl_uint>(12, newSize);
371
    findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
372
    findInteractingBlocksKernel.setArg<cl_uint>(9, newSize);
Peter Eastman's avatar
Peter Eastman committed
373
    if (context.getSIMDWidth() == 32 || deviceIsCpu) {
374
        delete interactionFlags;
Peter Eastman's avatar
Peter Eastman committed
375
        interactionFlags = new OpenCLArray<cl_uint>(context, deviceIsCpu ? 2*newSize : newSize, "interactionFlags");
376
        forceKernel.setArg<cl::Buffer>(13, interactionFlags->getDeviceBuffer());
377
        findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
378
379
380
381
382
383
        findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(4, interactingTiles->getDeviceBuffer());
        findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
        findInteractionsWithinBlocksKernel.setArg<cl_uint>(10, newSize);
    }
}

384
385
386
387
388
void OpenCLNonbondedUtilities::setTileRange(int startTileIndex, int numTiles) {
    this->startTileIndex = startTileIndex;
    this->numTiles = numTiles;
    if (cutoff == -1.0)
        return; // There are no nonbonded interactions in the System.
389
390
    forceKernel.setArg<cl_uint>(6, startTileIndex);
    forceKernel.setArg<cl_uint>(7, startTileIndex+numTiles);
391
392
393
394
    if (useCutoff) {
        findInteractingBlocksKernel.setArg<cl_uint>(10, startTileIndex);
        findInteractingBlocksKernel.setArg<cl_uint>(11, startTileIndex+numTiles);
    }
395
    else
396
        forceKernel.setArg<cl_uint>(8, numTiles);
397
398
}

399
cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& source, const vector<ParameterInfo>& params, const vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric) const {
400
401
    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = source;
402
    const string suffixes[] = {"x", "y", "z", "w"};
403
    stringstream localData;
404
    int localDataSize = 0;
405
    for (int i = 0; i < (int) params.size(); i++) {
406
407
408
409
410
411
        if (params[i].getNumComponents() == 1)
            localData<<params[i].getType()<<" "<<params[i].getName()<<";\n";
        else {
            for (int j = 0; j < params[i].getNumComponents(); ++j)
                localData<<params[i].getComponentType()<<" "<<params[i].getName()<<"_"<<suffixes[j]<<";\n";
        }
412
413
414
        localDataSize += params[i].getSize();
    }
    replacements["ATOM_PARAMETER_DATA"] = localData.str();
415
    stringstream args;
416
    for (int i = 0; i < (int) params.size(); i++) {
417
        args << ", __global const ";
418
        args << params[i].getType();
419
        args << "* restrict global_";
420
421
        args << params[i].getName();
    }
422
    for (int i = 0; i < (int) arguments.size(); i++) {
423
424
425
426
427
428
        if (arguments[i].getMemory().getInfo<CL_MEM_TYPE>() == CL_MEM_OBJECT_IMAGE2D) {
            args << ", __read_only image2d_t ";
            args << arguments[i].getName();
        }
        else {
            if ((arguments[i].getMemory().getInfo<CL_MEM_FLAGS>() & CL_MEM_READ_ONLY) == 0)
429
                args << ", __global const ";
430
431
432
            else
                args << ", __constant ";
            args << arguments[i].getType();
433
            args << "* restrict ";
434
435
            args << arguments[i].getName();
        }
436
    }
437
438
    replacements["PARAMETER_ARGUMENTS"] = args.str();
    stringstream loadLocal1;
439
    for (int i = 0; i < (int) params.size(); i++) {
440
        if (params[i].getNumComponents() == 1) {
441
            loadLocal1<<"localData[localAtomIndex]."<<params[i].getName()<<" = "<<params[i].getName()<<"1;\n";
442
443
444
        }
        else {
            for (int j = 0; j < params[i].getNumComponents(); ++j)
445
                loadLocal1<<"localData[localAtomIndex]."<<params[i].getName()<<"_"<<suffixes[j]<<" = "<<params[i].getName()<<"1."<<suffixes[j]<<";\n";
446
        }
447
448
449
    }
    replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
    stringstream loadLocal2;
450
    for (int i = 0; i < (int) params.size(); i++) {
451
        if (params[i].getNumComponents() == 1) {
452
            loadLocal2<<"localData[localAtomIndex]."<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
453
454
455
456
        }
        else {
            loadLocal2<<params[i].getType()<<" temp_"<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
            for (int j = 0; j < params[i].getNumComponents(); ++j)
457
                loadLocal2<<"localData[localAtomIndex]."<<params[i].getName()<<"_"<<suffixes[j]<<" = temp_"<<params[i].getName()<<"."<<suffixes[j]<<";\n";
458
        }
459
460
461
    }
    replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
    stringstream load1;
462
    for (int i = 0; i < (int) params.size(); i++) {
463
464
465
466
467
        load1 << params[i].getType();
        load1 << " ";
        load1 << params[i].getName();
        load1 << "1 = global_";
        load1 << params[i].getName();
468
        load1 << "[atom1];\n";
469
470
471
    }
    replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
    stringstream load2j;
472
    for (int i = 0; i < (int) params.size(); i++) {
473
        if (params[i].getNumComponents() == 1) {
474
            load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = localData[atom2]."<<params[i].getName()<<";\n";
475
476
477
478
479
480
        }
        else {
            load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = ("<<params[i].getType()<<") (";
            for (int j = 0; j < params[i].getNumComponents(); ++j) {
                if (j > 0)
                    load2j<<", ";
481
                load2j<<"localData[atom2]."<<params[i].getName()<<"_"<<suffixes[j];
482
483
484
            }
            load2j<<");\n";
        }
485
    }
486
    replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str();
487
488
489
490
491
492
493
494
495
    map<string, string> defines;
    if (forceBufferPerAtomBlock)
        defines["USE_OUTPUT_BUFFER_PER_BLOCK"] = "1";
    if (useCutoff)
        defines["USE_CUTOFF"] = "1";
    if (usePeriodic)
        defines["USE_PERIODIC"] = "1";
    if (useExclusions)
        defines["USE_EXCLUSIONS"] = "1";
496
497
    if (isSymmetric)
        defines["USE_SYMMETRIC"] = "1";
498
    defines["FORCE_WORK_GROUP_SIZE"] = OpenCLExpressionUtilities::intToString(forceThreadBlockSize);
499
    defines["CUTOFF_SQUARED"] = OpenCLExpressionUtilities::doubleToString(cutoff*cutoff);
500
501
    defines["NUM_ATOMS"] = OpenCLExpressionUtilities::intToString(context.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = OpenCLExpressionUtilities::intToString(context.getPaddedNumAtoms());
502
    defines["NUM_BLOCKS"] = OpenCLExpressionUtilities::intToString(context.getNumAtomBlocks());
503
504
    if ((localDataSize/4)%2 == 0)
        defines["PARAMETER_SIZE_IS_EVEN"] = "1";
505
506
507
508
509
510
511
    string file;
    if (deviceIsCpu)
        file = OpenCLKernelSources::nonbonded_cpu;
    else if (context.getSIMDWidth() == 32)
        file = OpenCLKernelSources::nonbonded_nvidia;
    else
        file = OpenCLKernelSources::nonbonded_default;
512
    cl::Program program = context.createProgram(context.replaceStrings(file, replacements), defines);
513
514
515
    cl::Kernel kernel(program, "computeNonbonded");

    // Set arguments to the Kernel.
516

517
    int index = 0;
518
519
520
521
    if (context.getSupports64BitGlobalAtomics())
        kernel.setArg<cl::Memory>(index++, context.getLongForceBuffer().getDeviceBuffer());
    else
        kernel.setArg<cl::Buffer>(index++, context.getForceBuffers().getDeviceBuffer());
522
523
524
    kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusions->getDeviceBuffer());
525
526
    kernel.setArg<cl::Buffer>(index++, exclusionIndices->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionRowIndices->getDeviceBuffer());
527
528
    kernel.setArg<cl_uint>(index++, startTileIndex);
    kernel.setArg<cl_uint>(index++, startTileIndex+numTiles);
529
    if (useCutoff) {
530
531
        kernel.setArg<cl::Buffer>(index++, interactingTiles->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, interactionCount->getDeviceBuffer());
532
        index += 2; // The periodic box size arguments are set when the kernel is executed.
533
        kernel.setArg<cl_uint>(index++, interactingTiles->getSize());
Peter Eastman's avatar
Peter Eastman committed
534
        kernel.setArg<cl::Buffer>(index++, interactionFlags->getDeviceBuffer());
535
    }
536
    else {
537
        kernel.setArg<cl_uint>(index++, numTiles);
538
    }
539
    for (int i = 0; i < (int) params.size(); i++) {
540
        kernel.setArg<cl::Memory>(index++, params[i].getMemory());
541
    }
542
    for (int i = 0; i < (int) arguments.size(); i++) {
543
        kernel.setArg<cl::Memory>(index++, arguments[i].getMemory());
544
    }
545
    return kernel;
546
}