OpenCLNonbondedUtilities.cpp 26.1 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2011 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "OpenCLNonbondedUtilities.h"
#include "OpenCLArray.h"
29
#include "OpenCLKernelSources.h"
30
#include "OpenCLExpressionUtilities.h"
31
#include <map>
32
33
#include <set>
#include <utility>
34
35
36
37

using namespace OpenMM;
using namespace std;

38
/**
 * Construct the utilities for one OpenCLContext and choose the launch configuration
 * of the nonbonded force kernel (number of work groups, work group size and number
 * of force buffers) from the device type, the SIMD width and the availability of
 * 64 bit global atomics.
 *
 * A cutoff of -1.0 is the marker for "no interaction has been added yet"; the device
 * arrays stay NULL until initialize() allocates them.
 */
OpenCLNonbondedUtilities::OpenCLNonbondedUtilities(OpenCLContext& context) : context(context), cutoff(-1.0), useCutoff(false), anyExclusions(false),
        numForceBuffers(0), exclusionIndices(NULL), exclusionRowIndices(NULL), exclusions(NULL), interactingTiles(NULL), interactionFlags(NULL),
        interactionCount(NULL), blockCenter(NULL), blockBoundingBox(NULL), nonbondedForceGroup(0) {
    // Give the remaining scalar state a defined value.  These members are otherwise only
    // assigned by addInteraction(), initialize() or setTileRange(), so reading them before
    // one of those calls would read an indeterminate value.
    usePeriodic = false;
    startTileIndex = 0;
    numTiles = 0;

    // Decide how many thread blocks and force buffers to use.

    deviceIsCpu = (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
    forceBufferPerAtomBlock = false;
    if (deviceIsCpu) {
        // On a CPU each work group consists of a single work item and owns one force buffer.
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = 1;
        numForceBuffers = numForceThreadBlocks;
    }
    else if (context.getSIMDWidth() == 32) {
        if (context.getSupports64BitGlobalAtomics()) {
            numForceThreadBlocks = 2*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
            forceThreadBlockSize = 256;
            numForceBuffers = 2;
        }
        else {
            numForceThreadBlocks = 4*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
            forceThreadBlockSize = 128;
            numForceBuffers = numForceThreadBlocks;
        }
    }
    else {
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = OpenCLContext::ThreadBlockSize;
        numForceBuffers = numForceThreadBlocks;
        if (numForceBuffers >= context.getNumAtomBlocks()) {
            // For small systems, it is more efficient to have one force buffer per block of 32 atoms instead of one per warp.

            forceBufferPerAtomBlock = true;
            numForceBuffers = context.getNumAtomBlocks();
        }
    }
}

/**
 * Release every device array owned by this object.  Arrays that were never
 * allocated are still NULL, and deleting a null pointer is a no-op.
 */
OpenCLNonbondedUtilities::~OpenCLNonbondedUtilities() {
    // Exclusion data.
    delete exclusionIndices;
    delete exclusionRowIndices;
    delete exclusions;

    // Neighbor list data.
    delete interactingTiles;
    delete interactionFlags;
    delete interactionCount;
    delete blockCenter;
    delete blockBoundingBox;
}

94
void OpenCLNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel, int forceGroup) {
95
96
97
98
99
100
101
    if (cutoff != -1.0) {
        if (usesCutoff != useCutoff)
            throw OpenMMException("All Forces must agree on whether to use a cutoff");
        if (usesPeriodic != usePeriodic)
            throw OpenMMException("All Forces must agree on whether to use periodic boundary conditions");
        if (cutoffDistance != cutoff)
            throw OpenMMException("All Forces must use the same cutoff distance");
102
103
        if (forceGroup != nonbondedForceGroup)
            throw OpenMMException("All nonbonded forces must be in the same force group");
104
    }
105
106
107
108
109
110
    if (usesExclusions)
        requestExclusions(exclusionList);
    useCutoff = usesCutoff;
    usePeriodic = usesPeriodic;
    cutoff = cutoffDistance;
    kernelSource += kernel+"\n";
111
    nonbondedForceGroup = forceGroup;
112
113
114
115
116
117
118
119
120
121
122
123
}

/**
 * Append a per-atom parameter to the list handed to createInteractionKernel()
 * when the default force kernel is built.
 */
void OpenCLNonbondedUtilities::addParameter(const ParameterInfo& parameter) {
    parameters.insert(parameters.end(), parameter);
}

/**
 * Append an extra kernel argument to the list handed to createInteractionKernel()
 * when the default force kernel is built.
 */
void OpenCLNonbondedUtilities::addArgument(const ParameterInfo& parameter) {
    arguments.insert(arguments.end(), parameter);
}

void OpenCLNonbondedUtilities::requestExclusions(const vector<vector<int> >& exclusionList) {
    if (anyExclusions) {
124
        bool sameExclusions = (exclusionList.size() == atomExclusions.size());
125
        for (int i = 0; i < (int) exclusionList.size() && sameExclusions; i++) {
126
127
            if (exclusionList[i].size() != atomExclusions[i].size())
                sameExclusions = false;
128
            for (int j = 0; j < (int) exclusionList[i].size(); j++)
129
130
131
132
133
134
                if (exclusionList[i][j] != atomExclusions[i][j])
                    sameExclusions = false;
        }
        if (!sameExclusions)
            throw OpenMMException("All Forces must have identical exceptions");
    }
135
    else {
136
        atomExclusions = exclusionList;
137
138
        anyExclusions = true;
    }
139
140
141
142
143
144
}

void OpenCLNonbondedUtilities::initialize(const System& system) {
    if (cutoff == -1.0)
        return; // There are no nonbonded interactions in the System.
    
145
146
147
148
    if (atomExclusions.size() == 0) {
        // No exclusions were specifically requested, so just mark every atom as not interacting with itself.
        
        atomExclusions.resize(context.getNumAtoms());
149
        for (int i = 0; i < (int) atomExclusions.size(); i++)
150
151
152
            atomExclusions[i].push_back(i);
    }

153
154
155
    // Create the list of tiles.

    int numAtomBlocks = context.getNumAtomBlocks();
156
157
158
159
160
    int totalTiles = numAtomBlocks*(numAtomBlocks+1)/2;
    int numContexts = context.getPlatformData().contexts.size();
    startTileIndex = context.getContextIndex()*totalTiles/numContexts;
    int endTileIndex = (context.getContextIndex()+1)*totalTiles/numContexts;
    numTiles = endTileIndex-startTileIndex;
161
162
163

    // Build a list of indices for the tiles with exclusions.

164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
    set<pair<int, int> > tilesWithExclusions;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            tilesWithExclusions.insert(make_pair(max(x, y), min(x, y)));
        }
    }
    if (context.getPaddedNumAtoms() > context.getNumAtoms()) {
        for (int i = 0; i < numAtomBlocks; ++i)
            tilesWithExclusions.insert(make_pair(numAtomBlocks-1, i));
    }
    vector<cl_uint> exclusionRowIndicesVec(numAtomBlocks+1, 0);
    vector<cl_uint> exclusionIndicesVec;
    int currentRow = 0;
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter) {
        while (iter->first != currentRow)
            exclusionRowIndicesVec[++currentRow] = exclusionIndicesVec.size();
        exclusionIndicesVec.push_back(iter->second);
    }
    exclusionRowIndicesVec[++currentRow] = exclusionIndicesVec.size();
    exclusionIndices = new OpenCLArray<cl_uint>(context, exclusionIndicesVec.size(), "exclusionIndices");
    exclusionRowIndices = new OpenCLArray<cl_uint>(context, exclusionRowIndicesVec.size(), "exclusionRowIndices");
    exclusionIndices->upload(exclusionIndicesVec);
    exclusionRowIndices->upload(exclusionRowIndicesVec);
190
191
192

    // Record the exclusion data.

193
    exclusions = new OpenCLArray<cl_uint>(context, tilesWithExclusions.size()*OpenCLContext::TileSize, "exclusions");
194
195
196
197
198
199
200
201
202
203
204
    vector<cl_uint> exclusionVec(exclusions->getSize());
    for (int i = 0; i < exclusions->getSize(); ++i)
        exclusionVec[i] = 0xFFFFFFFF;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x > y) {
205
206
                int index = findExclusionIndex(x, y, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset1] &= 0xFFFFFFFF-(1<<offset2);
207
208
            }
            else {
209
210
                int index = findExclusionIndex(y, x, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset2] &= 0xFFFFFFFF-(1<<offset1);
211
212
213
214
215
216
217
218
219
220
221
222
223
            }
        }
    }

    // Mark all interactions that involve a padding atom as being excluded.

    for (int atom1 = context.getNumAtoms(); atom1 < context.getPaddedNumAtoms(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int atom2 = 0; atom2 < context.getPaddedNumAtoms(); ++atom2) {
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x >= y) {
224
225
                int index = findExclusionIndex(x, y, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset1] &= 0xFFFFFFFF-(1<<offset2);
226
227
            }
            if (y >= x) {
228
229
                int index = findExclusionIndex(y, x, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset2] &= 0xFFFFFFFF-(1<<offset1);
230
231
232
233
234
            }
        }
    }
    atomExclusions.clear(); // We won't use this again, so free the memory it used
    exclusions->upload(exclusionVec);
235
236
237
238

    // Create data structures for the neighbor list.

    if (useCutoff) {
239
240
241
242
243
244
245
        // Select a size for the arrays that hold the neighbor list.  This estimate is intentionally very
        // high, because if it ever is too small, we have to fall back to the N^2 algorithm.

        mm_float4 boxSize = context.getPeriodicBoxSize();
        int maxInteractingTiles = (int) (numTiles*(cutoff/boxSize.x+cutoff/boxSize.y+cutoff/boxSize.z));
        if (maxInteractingTiles > numTiles)
            maxInteractingTiles = numTiles;
246
247
        if (maxInteractingTiles < 1)
            maxInteractingTiles = 1;
248
        interactingTiles = new OpenCLArray<mm_ushort2>(context, maxInteractingTiles, "interactingTiles");
249
        interactionFlags = new OpenCLArray<cl_uint>(context, context.getSIMDWidth() == 32 ? maxInteractingTiles : (deviceIsCpu ? 2*maxInteractingTiles : 1), "interactionFlags");
250
        interactionCount = new OpenCLArray<cl_uint>(context, 1, "interactionCount", true);
251
252
        blockCenter = new OpenCLArray<mm_float4>(context, numAtomBlocks, "blockCenter");
        blockBoundingBox = new OpenCLArray<mm_float4>(context, numAtomBlocks, "blockBoundingBox");
253
254
        interactionCount->set(0, 0);
        interactionCount->upload();
255
    }
256
257
258

    // Create kernels.

259
    forceKernel = createInteractionKernel(kernelSource, parameters, arguments, true, true);
260
261
    if (useCutoff) {
        map<string, string> defines;
262
        defines["NUM_BLOCKS"] = OpenCLExpressionUtilities::intToString(context.getNumAtomBlocks());
263
264
265
266
        if (forceBufferPerAtomBlock)
            defines["USE_OUTPUT_BUFFER_PER_BLOCK"] = "1";
        if (usePeriodic)
            defines["USE_PERIODIC"] = "1";
267
268
        string file = (deviceIsCpu ? OpenCLKernelSources::findInteractingBlocks_cpu : OpenCLKernelSources::findInteractingBlocks);
        cl::Program interactingBlocksProgram = context.createProgram(file, defines);
269
270
        findBlockBoundsKernel = cl::Kernel(interactingBlocksProgram, "findBlockBounds");
        findBlockBoundsKernel.setArg<cl_int>(0, context.getNumAtoms());
271
272
273
        findBlockBoundsKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
        findBlockBoundsKernel.setArg<cl::Buffer>(4, blockCenter->getDeviceBuffer());
        findBlockBoundsKernel.setArg<cl::Buffer>(5, blockBoundingBox->getDeviceBuffer());
274
        findBlockBoundsKernel.setArg<cl::Buffer>(6, interactionCount->getDeviceBuffer());
275
        findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions");
276
277
278
279
280
        findInteractingBlocksKernel.setArg<cl_float>(0, (cl_float) (cutoff*cutoff));
        findInteractingBlocksKernel.setArg<cl::Buffer>(3, blockCenter->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(4, blockBoundingBox->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(5, interactionCount->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
281
282
283
        findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(8, context.getPosq().getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl_uint>(9, interactingTiles->getSize());
284
285
        findInteractingBlocksKernel.setArg<cl_uint>(10, startTileIndex);
        findInteractingBlocksKernel.setArg<cl_uint>(11, startTileIndex+numTiles);
286
        if (context.getSIMDWidth() == 32 && !deviceIsCpu) {
287
288
289
290
291
292
293
294
            findInteractionsWithinBlocksKernel = cl::Kernel(interactingBlocksProgram, "findInteractionsWithinBlocks");
            findInteractionsWithinBlocksKernel.setArg<cl_float>(0, (cl_float) (cutoff*cutoff));
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(4, interactingTiles->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(5, blockCenter->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(6, blockBoundingBox->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(8, interactionCount->getDeviceBuffer());
295
            findInteractionsWithinBlocksKernel.setArg(9, 128*sizeof(cl_uint), NULL);
296
            findInteractionsWithinBlocksKernel.setArg<cl_uint>(10, interactingTiles->getSize());
297
        }
298
    }
299
300
}

301
302
303
304
305
306
307
308
309
int OpenCLNonbondedUtilities::findExclusionIndex(int x, int y, const vector<cl_uint>& exclusionIndices, const vector<cl_uint>& exclusionRowIndices) {
    int start = exclusionRowIndices[x];
    int end = exclusionRowIndices[x+1];
    for (int i = start; i < end; i++)
        if (exclusionIndices[i] == y)
            return i*OpenCLContext::TileSize;
    throw OpenMMException("Internal error: exclusion in unexpected tile");
}

310
311
312
void OpenCLNonbondedUtilities::prepareInteractions() {
    if (!useCutoff)
        return;
313
314
315

    // Compute the neighbor list.

316
317
    findBlockBoundsKernel.setArg<mm_float4>(1, context.getPeriodicBoxSize());
    findBlockBoundsKernel.setArg<mm_float4>(2, context.getInvPeriodicBoxSize());
318
    context.executeKernel(findBlockBoundsKernel, context.getNumAtoms());
319
320
    findInteractingBlocksKernel.setArg<mm_float4>(1, context.getPeriodicBoxSize());
    findInteractingBlocksKernel.setArg<mm_float4>(2, context.getInvPeriodicBoxSize());
321
322
    context.executeKernel(findInteractingBlocksKernel, context.getNumAtoms(), deviceIsCpu ? 1 : -1);
    if (context.getSIMDWidth() == 32 && !deviceIsCpu) {
323
324
        findInteractionsWithinBlocksKernel.setArg<mm_float4>(1, context.getPeriodicBoxSize());
        findInteractionsWithinBlocksKernel.setArg<mm_float4>(2, context.getInvPeriodicBoxSize());
325
        context.executeKernel(findInteractionsWithinBlocksKernel, context.getNumAtoms(), 128);
326
    }
327
328
329
}

void OpenCLNonbondedUtilities::computeInteractions() {
330
    if (cutoff != -1.0) {
331
        if (useCutoff) {
332
333
            forceKernel.setArg<mm_float4>(10, context.getPeriodicBoxSize());
            forceKernel.setArg<mm_float4>(11, context.getInvPeriodicBoxSize());
334
        }
335
        context.executeKernel(forceKernel, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
336
    }
337
338
}

339
340
341
342
void OpenCLNonbondedUtilities::updateNeighborListSize() {
    if (!useCutoff)
        return;
    interactionCount->download();
343
    if (interactionCount->get(0) <= (unsigned int) interactingTiles->getSize())
344
345
346
347
348
349
350
351
352
353
354
        return;

    // The most recent timestep had too many interactions to fit in the arrays.  Make the arrays bigger to prevent
    // this from happening in the future.

    int newSize = (int) (1.2*interactionCount->get(0));
    int numTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    if (newSize > numTiles)
        newSize = numTiles;
    delete interactingTiles;
    interactingTiles = new OpenCLArray<mm_ushort2>(context, newSize, "interactingTiles");
355
356
    forceKernel.setArg<cl::Buffer>(8, interactingTiles->getDeviceBuffer());
    forceKernel.setArg<cl_uint>(12, newSize);
357
    findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
358
    findInteractingBlocksKernel.setArg<cl_uint>(9, newSize);
Peter Eastman's avatar
Peter Eastman committed
359
    if (context.getSIMDWidth() == 32 || deviceIsCpu) {
360
        delete interactionFlags;
Peter Eastman's avatar
Peter Eastman committed
361
        interactionFlags = new OpenCLArray<cl_uint>(context, deviceIsCpu ? 2*newSize : newSize, "interactionFlags");
362
        forceKernel.setArg<cl::Buffer>(13, interactionFlags->getDeviceBuffer());
363
        findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
364
365
366
367
368
369
        findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(4, interactingTiles->getDeviceBuffer());
        findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
        findInteractionsWithinBlocksKernel.setArg<cl_uint>(10, newSize);
    }
}

370
371
372
373
374
void OpenCLNonbondedUtilities::setTileRange(int startTileIndex, int numTiles) {
    this->startTileIndex = startTileIndex;
    this->numTiles = numTiles;
    if (cutoff == -1.0)
        return; // There are no nonbonded interactions in the System.
375
376
    forceKernel.setArg<cl_uint>(6, startTileIndex);
    forceKernel.setArg<cl_uint>(7, startTileIndex+numTiles);
377
378
379
380
    if (useCutoff) {
        findInteractingBlocksKernel.setArg<cl_uint>(10, startTileIndex);
        findInteractingBlocksKernel.setArg<cl_uint>(11, startTileIndex+numTiles);
    }
381
    else
382
        forceKernel.setArg<cl_uint>(8, numTiles);
383
384
}

385
/**
 * Build a "computeNonbonded" kernel for an interaction.
 *
 * The interaction source and generated code fragments that declare, load and pass the
 * per-atom parameters are substituted into the kernel template selected for the
 * device, the program is compiled with the matching defines, and all arguments
 * except the periodic box size are bound.
 *
 * @param source         source code computing the interaction
 * @param params         the per-atom parameters the interaction depends on
 * @param arguments      additional arguments to pass to the kernel
 * @param useExclusions  whether USE_EXCLUSIONS is defined for the kernel
 * @param isSymmetric    whether USE_SYMMETRIC is defined for the kernel
 * @return the compiled kernel with its arguments set
 */
cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& source, const vector<ParameterInfo>& params, const vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric) const {
    const string suffixes[] = {"x", "y", "z", "w"};
    const int numParams = (int) params.size();
    const int numArguments = (int) arguments.size();

    // Generate every code fragment that depends on the per-atom parameters in a single pass.

    stringstream localData;  // member declarations of the local data struct
    stringstream args;       // extra kernel arguments
    stringstream loadLocal1; // copy the parameters of atom 1 into local data
    stringstream loadLocal2; // copy parameters from global memory into local data
    stringstream load1;      // load the parameters of atom 1
    stringstream load2j;     // load the parameters of atom 2 from local data
    int localDataSize = 0;
    for (int p = 0; p < numParams; p++) {
        const ParameterInfo& param = params[p];
        const int numComponents = param.getNumComponents();

        args<<", __global const "<<param.getType()<<"* restrict global_"<<param.getName();
        load1<<param.getType()<<" "<<param.getName()<<"1 = global_"<<param.getName()<<"[atom1];\n";
        localDataSize += param.getSize();

        if (numComponents == 1) {
            localData<<param.getType()<<" "<<param.getName()<<";\n";
            loadLocal1<<"localData[localAtomIndex]."<<param.getName()<<" = "<<param.getName()<<"1;\n";
            loadLocal2<<"localData[localAtomIndex]."<<param.getName()<<" = global_"<<param.getName()<<"[j];\n";
            load2j<<param.getType()<<" "<<param.getName()<<"2 = localData[atom2]."<<param.getName()<<";\n";
            continue;
        }

        // Vector parameters are stored in local data as one scalar member per component.
        loadLocal2<<param.getType()<<" temp_"<<param.getName()<<" = global_"<<param.getName()<<"[j];\n";
        load2j<<param.getType()<<" "<<param.getName()<<"2 = ("<<param.getType()<<") (";
        for (int c = 0; c < numComponents; ++c) {
            localData<<param.getComponentType()<<" "<<param.getName()<<"_"<<suffixes[c]<<";\n";
            loadLocal1<<"localData[localAtomIndex]."<<param.getName()<<"_"<<suffixes[c]<<" = "<<param.getName()<<"1."<<suffixes[c]<<";\n";
            loadLocal2<<"localData[localAtomIndex]."<<param.getName()<<"_"<<suffixes[c]<<" = temp_"<<param.getName()<<"."<<suffixes[c]<<";\n";
            if (c > 0)
                load2j<<", ";
            load2j<<"localData[atom2]."<<param.getName()<<"_"<<suffixes[c];
        }
        load2j<<");\n";
    }

    // The additional arguments follow the parameters in the kernel signature.

    for (int a = 0; a < numArguments; a++) {
        const ParameterInfo& argument = arguments[a];
        if (argument.getMemory().getInfo<CL_MEM_TYPE>() == CL_MEM_OBJECT_IMAGE2D) {
            args<<", __read_only image2d_t "<<argument.getName();
            continue;
        }
        const bool isReadOnly = ((argument.getMemory().getInfo<CL_MEM_FLAGS>() & CL_MEM_READ_ONLY) != 0);
        args<<(isReadOnly ? ", __constant " : ", __global const ");
        args<<argument.getType()<<"* restrict "<<argument.getName();
    }

    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = source;
    replacements["ATOM_PARAMETER_DATA"] = localData.str();
    replacements["PARAMETER_ARGUMENTS"] = args.str();
    replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
    replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
    replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
    replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str();

    // Preprocessor definitions describing the configuration.

    map<string, string> defines;
    if (forceBufferPerAtomBlock)
        defines["USE_OUTPUT_BUFFER_PER_BLOCK"] = "1";
    if (useCutoff)
        defines["USE_CUTOFF"] = "1";
    if (usePeriodic)
        defines["USE_PERIODIC"] = "1";
    if (useExclusions)
        defines["USE_EXCLUSIONS"] = "1";
    if (isSymmetric)
        defines["USE_SYMMETRIC"] = "1";
    if ((localDataSize/4)%2 == 0)
        defines["PARAMETER_SIZE_IS_EVEN"] = "1";
    defines["FORCE_WORK_GROUP_SIZE"] = OpenCLExpressionUtilities::intToString(forceThreadBlockSize);
    defines["CUTOFF_SQUARED"] = OpenCLExpressionUtilities::doubleToString(cutoff*cutoff);
    defines["NUM_ATOMS"] = OpenCLExpressionUtilities::intToString(context.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = OpenCLExpressionUtilities::intToString(context.getPaddedNumAtoms());
    defines["NUM_BLOCKS"] = OpenCLExpressionUtilities::intToString(context.getNumAtomBlocks());

    // Select the kernel template for the device and compile it.

    string file = (deviceIsCpu ? OpenCLKernelSources::nonbonded_cpu :
            (context.getSIMDWidth() == 32 ? OpenCLKernelSources::nonbonded_nvidia : OpenCLKernelSources::nonbonded_default));
    cl::Program program = context.createProgram(context.replaceStrings(file, replacements), defines);
    cl::Kernel kernel(program, "computeNonbonded");

    // Set arguments to the Kernel.

    int index = 0;
    if (context.getSupports64BitGlobalAtomics())
        kernel.setArg<cl::Memory>(index++, context.getLongForceBuffer().getDeviceBuffer());
    else
        kernel.setArg<cl::Buffer>(index++, context.getForceBuffers().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusions->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionIndices->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionRowIndices->getDeviceBuffer());
    kernel.setArg<cl_uint>(index++, startTileIndex);
    kernel.setArg<cl_uint>(index++, startTileIndex+numTiles);
    if (!useCutoff)
        kernel.setArg<cl_uint>(index++, numTiles);
    else {
        kernel.setArg<cl::Buffer>(index++, interactingTiles->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, interactionCount->getDeviceBuffer());
        index += 2; // The periodic box size arguments are set when the kernel is executed.
        kernel.setArg<cl_uint>(index++, interactingTiles->getSize());
        kernel.setArg<cl::Buffer>(index++, interactionFlags->getDeviceBuffer());
    }
    for (int p = 0; p < numParams; p++)
        kernel.setArg<cl::Memory>(index++, params[p].getMemory());
    for (int a = 0; a < numArguments; a++)
        kernel.setArg<cl::Memory>(index++, arguments[a].getMemory());
    return kernel;
}