"...reference/src/SimTKReference/ReferenceLJCoulombIxn.cpp" did not exist on "ae4c6f96f9d551624813016e84b96a228a37dbec"
OpenCLNonbondedUtilities.cpp 26.1 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2011 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "OpenCLNonbondedUtilities.h"
#include "OpenCLArray.h"
29
#include "OpenCLKernelSources.h"
30
#include "OpenCLExpressionUtilities.h"
31
#include <map>
32
33
#include <set>
#include <utility>
34
35
36
37

using namespace OpenMM;
using namespace std;

38
OpenCLNonbondedUtilities::OpenCLNonbondedUtilities(OpenCLContext& context) : context(context), cutoff(-1.0), useCutoff(false), anyExclusions(false),
39
        numForceBuffers(0), exclusionIndices(NULL), exclusionRowIndices(NULL), exclusions(NULL), interactingTiles(NULL), interactionFlags(NULL),
40
        interactionCount(NULL), blockCenter(NULL), blockBoundingBox(NULL) {
41
    // Decide how many thread blocks and force buffers to use.
42

43
    deviceIsCpu = (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
44
    forceBufferPerAtomBlock = false;
45
46
47
48
49
50
    if (deviceIsCpu) {
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = 1;
        numForceBuffers = numForceThreadBlocks;
    }
    else if (context.getSIMDWidth() == 32) {
51
        if (context.getSupports64BitGlobalAtomics()) {
52
            numForceThreadBlocks = 2*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
53
54
55
56
57
58
59
60
            forceThreadBlockSize = 256;
            numForceBuffers = 2;
        }
        else {
            numForceThreadBlocks = 4*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
            forceThreadBlockSize = 128;
            numForceBuffers = numForceThreadBlocks;
        }
61
    }
62
    else {
63
64
        numForceThreadBlocks = context.getNumThreadBlocks();
        forceThreadBlockSize = OpenCLContext::ThreadBlockSize;
65
        numForceBuffers = numForceThreadBlocks;
66
67
        if (numForceBuffers >= context.getNumAtomBlocks()) {
            // For small systems, it is more efficient to have one force buffer per block of 32 atoms instead of one per warp.
68

69
70
71
            forceBufferPerAtomBlock = true;
            numForceBuffers = context.getNumAtomBlocks();
        }
72
    }
73
74
75
}

/**
 * Release the device arrays created by initialize() and updateNeighborListSize().
 * Arrays that were never allocated are still NULL, and deleting NULL is a no-op,
 * so no per-pointer guard is needed.
 */
OpenCLNonbondedUtilities::~OpenCLNonbondedUtilities() {
    // Exclusion data.
    delete exclusionIndices;
    delete exclusionRowIndices;
    delete exclusions;

    // Neighbor-list data (only allocated when a cutoff is used).
    delete interactingTiles;
    delete interactionFlags;
    delete interactionCount;
    delete blockCenter;
    delete blockBoundingBox;
}

94
void OpenCLNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel) {
95
96
97
98
99
100
101
    if (cutoff != -1.0) {
        if (usesCutoff != useCutoff)
            throw OpenMMException("All Forces must agree on whether to use a cutoff");
        if (usesPeriodic != usePeriodic)
            throw OpenMMException("All Forces must agree on whether to use periodic boundary conditions");
        if (cutoffDistance != cutoff)
            throw OpenMMException("All Forces must use the same cutoff distance");
102
    }
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
    if (usesExclusions)
        requestExclusions(exclusionList);
    useCutoff = usesCutoff;
    usePeriodic = usesPeriodic;
    cutoff = cutoffDistance;
    kernelSource += kernel+"\n";
}

/**
 * Register a per-atom parameter.  When initialize() builds the force kernel, each
 * registered parameter becomes a kernel argument named global_<name> and is loaded
 * for both atoms of every interacting pair.
 */
void OpenCLNonbondedUtilities::addParameter(const ParameterInfo& parameter) {
    parameters.insert(parameters.end(), parameter);
}

/**
 * Register an extra buffer or image that is passed to the force kernel unchanged,
 * under its own name, after all per-atom parameters.
 */
void OpenCLNonbondedUtilities::addArgument(const ParameterInfo& parameter) {
    arguments.insert(arguments.end(), parameter);
}

void OpenCLNonbondedUtilities::requestExclusions(const vector<vector<int> >& exclusionList) {
    if (anyExclusions) {
121
        bool sameExclusions = (exclusionList.size() == atomExclusions.size());
122
        for (int i = 0; i < (int) exclusionList.size() && sameExclusions; i++) {
123
124
            if (exclusionList[i].size() != atomExclusions[i].size())
                sameExclusions = false;
125
            for (int j = 0; j < (int) exclusionList[i].size(); j++)
126
127
128
129
130
131
                if (exclusionList[i][j] != atomExclusions[i][j])
                    sameExclusions = false;
        }
        if (!sameExclusions)
            throw OpenMMException("All Forces must have identical exceptions");
    }
132
    else {
133
        atomExclusions = exclusionList;
134
135
        anyExclusions = true;
    }
136
137
138
139
140
141
}

void OpenCLNonbondedUtilities::initialize(const System& system) {
    if (cutoff == -1.0)
        return; // There are no nonbonded interactions in the System.
    
142
143
144
145
    if (atomExclusions.size() == 0) {
        // No exclusions were specifically requested, so just mark every atom as not interacting with itself.
        
        atomExclusions.resize(context.getNumAtoms());
146
        for (int i = 0; i < (int) atomExclusions.size(); i++)
147
148
149
            atomExclusions[i].push_back(i);
    }

150
151
152
    // Create the list of tiles.

    int numAtomBlocks = context.getNumAtomBlocks();
153
154
155
156
157
    int totalTiles = numAtomBlocks*(numAtomBlocks+1)/2;
    int numContexts = context.getPlatformData().contexts.size();
    startTileIndex = context.getContextIndex()*totalTiles/numContexts;
    int endTileIndex = (context.getContextIndex()+1)*totalTiles/numContexts;
    numTiles = endTileIndex-startTileIndex;
158
159
160

    // Build a list of indices for the tiles with exclusions.

161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
    set<pair<int, int> > tilesWithExclusions;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            tilesWithExclusions.insert(make_pair(max(x, y), min(x, y)));
        }
    }
    if (context.getPaddedNumAtoms() > context.getNumAtoms()) {
        for (int i = 0; i < numAtomBlocks; ++i)
            tilesWithExclusions.insert(make_pair(numAtomBlocks-1, i));
    }
    vector<cl_uint> exclusionRowIndicesVec(numAtomBlocks+1, 0);
    vector<cl_uint> exclusionIndicesVec;
    int currentRow = 0;
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter) {
        while (iter->first != currentRow)
            exclusionRowIndicesVec[++currentRow] = exclusionIndicesVec.size();
        exclusionIndicesVec.push_back(iter->second);
    }
    exclusionRowIndicesVec[++currentRow] = exclusionIndicesVec.size();
    exclusionIndices = new OpenCLArray<cl_uint>(context, exclusionIndicesVec.size(), "exclusionIndices");
    exclusionRowIndices = new OpenCLArray<cl_uint>(context, exclusionRowIndicesVec.size(), "exclusionRowIndices");
    exclusionIndices->upload(exclusionIndicesVec);
    exclusionRowIndices->upload(exclusionRowIndicesVec);
187
188
189

    // Record the exclusion data.

190
    exclusions = new OpenCLArray<cl_uint>(context, tilesWithExclusions.size()*OpenCLContext::TileSize, "exclusions");
191
192
193
194
195
196
197
198
199
200
201
    vector<cl_uint> exclusionVec(exclusions->getSize());
    for (int i = 0; i < exclusions->getSize(); ++i)
        exclusionVec[i] = 0xFFFFFFFF;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x > y) {
202
203
                int index = findExclusionIndex(x, y, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset1] &= 0xFFFFFFFF-(1<<offset2);
204
205
            }
            else {
206
207
                int index = findExclusionIndex(y, x, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset2] &= 0xFFFFFFFF-(1<<offset1);
208
209
210
211
212
213
214
215
216
217
218
219
220
            }
        }
    }

    // Mark all interactions that involve a padding atom as being excluded.

    for (int atom1 = context.getNumAtoms(); atom1 < context.getPaddedNumAtoms(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int atom2 = 0; atom2 < context.getPaddedNumAtoms(); ++atom2) {
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x >= y) {
221
222
                int index = findExclusionIndex(x, y, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset1] &= 0xFFFFFFFF-(1<<offset2);
223
224
            }
            if (y >= x) {
225
226
                int index = findExclusionIndex(y, x, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset2] &= 0xFFFFFFFF-(1<<offset1);
227
228
229
230
231
            }
        }
    }
    atomExclusions.clear(); // We won't use this again, so free the memory it used
    exclusions->upload(exclusionVec);
232
233
234
235

    // Create data structures for the neighbor list.

    if (useCutoff) {
236
237
238
239
240
241
242
        // Select a size for the arrays that hold the neighbor list.  This estimate is intentionally very
        // high, because if it ever is too small, we have to fall back to the N^2 algorithm.

        mm_float4 boxSize = context.getPeriodicBoxSize();
        int maxInteractingTiles = (int) (numTiles*(cutoff/boxSize.x+cutoff/boxSize.y+cutoff/boxSize.z));
        if (maxInteractingTiles > numTiles)
            maxInteractingTiles = numTiles;
243
244
        if (maxInteractingTiles < 1)
            maxInteractingTiles = 1;
245
        interactingTiles = new OpenCLArray<mm_ushort2>(context, maxInteractingTiles, "interactingTiles");
246
        interactionFlags = new OpenCLArray<cl_uint>(context, context.getSIMDWidth() == 32 ? maxInteractingTiles : (deviceIsCpu ? 2*maxInteractingTiles : 1), "interactionFlags");
247
        interactionCount = new OpenCLArray<cl_uint>(context, 1, "interactionCount", true);
248
249
        blockCenter = new OpenCLArray<mm_float4>(context, numAtomBlocks, "blockCenter");
        blockBoundingBox = new OpenCLArray<mm_float4>(context, numAtomBlocks, "blockBoundingBox");
250
251
        interactionCount->set(0, 0);
        interactionCount->upload();
252
    }
253
254
255

    // Create kernels.

256
    forceKernel = createInteractionKernel(kernelSource, parameters, arguments, true, true);
257
258
    if (useCutoff) {
        map<string, string> defines;
259
        defines["NUM_BLOCKS"] = OpenCLExpressionUtilities::intToString(context.getNumAtomBlocks());
260
261
262
263
        if (forceBufferPerAtomBlock)
            defines["USE_OUTPUT_BUFFER_PER_BLOCK"] = "1";
        if (usePeriodic)
            defines["USE_PERIODIC"] = "1";
264
265
        string file = (deviceIsCpu ? OpenCLKernelSources::findInteractingBlocks_cpu : OpenCLKernelSources::findInteractingBlocks);
        cl::Program interactingBlocksProgram = context.createProgram(file, defines);
266
267
        findBlockBoundsKernel = cl::Kernel(interactingBlocksProgram, "findBlockBounds");
        findBlockBoundsKernel.setArg<cl_int>(0, context.getNumAtoms());
268
269
270
        findBlockBoundsKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
        findBlockBoundsKernel.setArg<cl::Buffer>(4, blockCenter->getDeviceBuffer());
        findBlockBoundsKernel.setArg<cl::Buffer>(5, blockBoundingBox->getDeviceBuffer());
271
        findBlockBoundsKernel.setArg<cl::Buffer>(6, interactionCount->getDeviceBuffer());
272
        findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions");
273
274
275
276
277
        findInteractingBlocksKernel.setArg<cl_float>(0, (cl_float) (cutoff*cutoff));
        findInteractingBlocksKernel.setArg<cl::Buffer>(3, blockCenter->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(4, blockBoundingBox->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(5, interactionCount->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
278
279
280
        findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(8, context.getPosq().getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl_uint>(9, interactingTiles->getSize());
281
282
        findInteractingBlocksKernel.setArg<cl_uint>(10, startTileIndex);
        findInteractingBlocksKernel.setArg<cl_uint>(11, startTileIndex+numTiles);
283
        if (context.getSIMDWidth() == 32 && !deviceIsCpu) {
284
285
286
287
288
289
290
291
            findInteractionsWithinBlocksKernel = cl::Kernel(interactingBlocksProgram, "findInteractionsWithinBlocks");
            findInteractionsWithinBlocksKernel.setArg<cl_float>(0, (cl_float) (cutoff*cutoff));
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(4, interactingTiles->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(5, blockCenter->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(6, blockBoundingBox->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(8, interactionCount->getDeviceBuffer());
292
            findInteractionsWithinBlocksKernel.setArg(9, 128*sizeof(cl_uint), NULL);
293
            findInteractionsWithinBlocksKernel.setArg<cl_uint>(10, interactingTiles->getSize());
294
        }
295
    }
296
297
}

298
299
300
301
302
303
304
305
306
int OpenCLNonbondedUtilities::findExclusionIndex(int x, int y, const vector<cl_uint>& exclusionIndices, const vector<cl_uint>& exclusionRowIndices) {
    int start = exclusionRowIndices[x];
    int end = exclusionRowIndices[x+1];
    for (int i = start; i < end; i++)
        if (exclusionIndices[i] == y)
            return i*OpenCLContext::TileSize;
    throw OpenMMException("Internal error: exclusion in unexpected tile");
}

307
308
309
void OpenCLNonbondedUtilities::prepareInteractions() {
    if (!useCutoff)
        return;
310
311
312

    // Compute the neighbor list.

313
314
    findBlockBoundsKernel.setArg<mm_float4>(1, context.getPeriodicBoxSize());
    findBlockBoundsKernel.setArg<mm_float4>(2, context.getInvPeriodicBoxSize());
315
    context.executeKernel(findBlockBoundsKernel, context.getNumAtoms());
316
317
    findInteractingBlocksKernel.setArg<mm_float4>(1, context.getPeriodicBoxSize());
    findInteractingBlocksKernel.setArg<mm_float4>(2, context.getInvPeriodicBoxSize());
318
319
    context.executeKernel(findInteractingBlocksKernel, context.getNumAtoms(), deviceIsCpu ? 1 : -1);
    if (context.getSIMDWidth() == 32 && !deviceIsCpu) {
320
321
        findInteractionsWithinBlocksKernel.setArg<mm_float4>(1, context.getPeriodicBoxSize());
        findInteractionsWithinBlocksKernel.setArg<mm_float4>(2, context.getInvPeriodicBoxSize());
322
        context.executeKernel(findInteractionsWithinBlocksKernel, context.getNumAtoms(), 128);
323
    }
324
325
326
}

void OpenCLNonbondedUtilities::computeInteractions() {
327
    if (cutoff != -1.0) {
328
        if (useCutoff) {
329
330
            forceKernel.setArg<mm_float4>(12, context.getPeriodicBoxSize());
            forceKernel.setArg<mm_float4>(13, context.getInvPeriodicBoxSize());
331
        }
332
        context.executeKernel(forceKernel, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
333
    }
334
335
}

336
337
338
339
void OpenCLNonbondedUtilities::updateNeighborListSize() {
    if (!useCutoff)
        return;
    interactionCount->download();
340
    if (interactionCount->get(0) <= (unsigned int) interactingTiles->getSize())
341
342
343
344
345
346
347
348
349
350
351
        return;

    // The most recent timestep had too many interactions to fit in the arrays.  Make the arrays bigger to prevent
    // this from happening in the future.

    int newSize = (int) (1.2*interactionCount->get(0));
    int numTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    if (newSize > numTiles)
        newSize = numTiles;
    delete interactingTiles;
    interactingTiles = new OpenCLArray<mm_ushort2>(context, newSize, "interactingTiles");
352
353
    forceKernel.setArg<cl::Buffer>(10, interactingTiles->getDeviceBuffer());
    forceKernel.setArg<cl_uint>(14, newSize);
354
    findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
355
    findInteractingBlocksKernel.setArg<cl_uint>(9, newSize);
Peter Eastman's avatar
Peter Eastman committed
356
    if (context.getSIMDWidth() == 32 || deviceIsCpu) {
357
        delete interactionFlags;
Peter Eastman's avatar
Peter Eastman committed
358
        interactionFlags = new OpenCLArray<cl_uint>(context, deviceIsCpu ? 2*newSize : newSize, "interactionFlags");
359
        forceKernel.setArg<cl::Buffer>(15, interactionFlags->getDeviceBuffer());
360
        findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
361
362
363
364
365
366
        findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(4, interactingTiles->getDeviceBuffer());
        findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
        findInteractionsWithinBlocksKernel.setArg<cl_uint>(10, newSize);
    }
}

367
368
369
370
371
372
373
374
375
376
377
378
379
void OpenCLNonbondedUtilities::setTileRange(int startTileIndex, int numTiles) {
    this->startTileIndex = startTileIndex;
    this->numTiles = numTiles;
    if (cutoff == -1.0)
        return; // There are no nonbonded interactions in the System.
    forceKernel.setArg<cl_uint>(8, startTileIndex);
    forceKernel.setArg<cl_uint>(9, startTileIndex+numTiles);
    if (useCutoff) {
        findInteractingBlocksKernel.setArg<cl_uint>(10, startTileIndex);
        findInteractingBlocksKernel.setArg<cl_uint>(11, startTileIndex+numTiles);
    }
}

380
cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& source, const vector<ParameterInfo>& params, const vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric) const {
381
382
    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = source;
383
    int localDataSize = 7*sizeof(cl_float);
384
    const string suffixes[] = {"x", "y", "z", "w"};
385
386
    stringstream localData;
    for (int i = 0; i < (int) params.size(); i++) {
387
388
389
390
391
392
        if (params[i].getNumComponents() == 1)
            localData<<params[i].getType()<<" "<<params[i].getName()<<";\n";
        else {
            for (int j = 0; j < params[i].getNumComponents(); ++j)
                localData<<params[i].getComponentType()<<" "<<params[i].getName()<<"_"<<suffixes[j]<<";\n";
        }
393
394
395
396
397
398
399
        localDataSize += params[i].getSize();
    }
    if ((localDataSize/4)%2 == 0) {
        localData << "float padding;\n";
        localDataSize += 4;
    }
    replacements["ATOM_PARAMETER_DATA"] = localData.str();
400
    stringstream args;
401
    for (int i = 0; i < (int) params.size(); i++) {
402
403
404
405
406
        args << ", __global ";
        args << params[i].getType();
        args << "* global_";
        args << params[i].getName();
    }
407
    for (int i = 0; i < (int) arguments.size(); i++) {
408
409
410
411
412
413
414
415
416
417
418
419
420
        if (arguments[i].getMemory().getInfo<CL_MEM_TYPE>() == CL_MEM_OBJECT_IMAGE2D) {
            args << ", __read_only image2d_t ";
            args << arguments[i].getName();
        }
        else {
            if ((arguments[i].getMemory().getInfo<CL_MEM_FLAGS>() & CL_MEM_READ_ONLY) == 0)
                args << ", __global ";
            else
                args << ", __constant ";
            args << arguments[i].getType();
            args << "* ";
            args << arguments[i].getName();
        }
421
    }
422
423
    replacements["PARAMETER_ARGUMENTS"] = args.str();
    stringstream loadLocal1;
424
    for (int i = 0; i < (int) params.size(); i++) {
425
        if (params[i].getNumComponents() == 1) {
426
            loadLocal1<<"localData[localAtomIndex]."<<params[i].getName()<<" = "<<params[i].getName()<<"1;\n";
427
428
429
        }
        else {
            for (int j = 0; j < params[i].getNumComponents(); ++j)
430
                loadLocal1<<"localData[localAtomIndex]."<<params[i].getName()<<"_"<<suffixes[j]<<" = "<<params[i].getName()<<"1."<<suffixes[j]<<";\n";
431
        }
432
433
434
    }
    replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
    stringstream loadLocal2;
435
    for (int i = 0; i < (int) params.size(); i++) {
436
        if (params[i].getNumComponents() == 1) {
437
            loadLocal2<<"localData[localAtomIndex]."<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
438
439
440
441
        }
        else {
            loadLocal2<<params[i].getType()<<" temp_"<<params[i].getName()<<" = global_"<<params[i].getName()<<"[j];\n";
            for (int j = 0; j < params[i].getNumComponents(); ++j)
442
                loadLocal2<<"localData[localAtomIndex]."<<params[i].getName()<<"_"<<suffixes[j]<<" = temp_"<<params[i].getName()<<"."<<suffixes[j]<<";\n";
443
        }
444
445
446
    }
    replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
    stringstream load1;
447
    for (int i = 0; i < (int) params.size(); i++) {
448
449
450
451
452
        load1 << params[i].getType();
        load1 << " ";
        load1 << params[i].getName();
        load1 << "1 = global_";
        load1 << params[i].getName();
453
        load1 << "[atom1];\n";
454
455
456
    }
    replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
    stringstream load2j;
457
    for (int i = 0; i < (int) params.size(); i++) {
458
        if (params[i].getNumComponents() == 1) {
459
            load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = localData[atom2]."<<params[i].getName()<<";\n";
460
461
462
463
464
465
        }
        else {
            load2j<<params[i].getType()<<" "<<params[i].getName()<<"2 = ("<<params[i].getType()<<") (";
            for (int j = 0; j < params[i].getNumComponents(); ++j) {
                if (j > 0)
                    load2j<<", ";
466
                load2j<<"localData[atom2]."<<params[i].getName()<<"_"<<suffixes[j];
467
468
469
            }
            load2j<<");\n";
        }
470
    }
471
    replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str();
472
473
474
475
476
477
478
479
480
    map<string, string> defines;
    if (forceBufferPerAtomBlock)
        defines["USE_OUTPUT_BUFFER_PER_BLOCK"] = "1";
    if (useCutoff)
        defines["USE_CUTOFF"] = "1";
    if (usePeriodic)
        defines["USE_PERIODIC"] = "1";
    if (useExclusions)
        defines["USE_EXCLUSIONS"] = "1";
481
482
    if (isSymmetric)
        defines["USE_SYMMETRIC"] = "1";
483
484
    if (context.getSIMDWidth() == 32)
        defines["WARPS_PER_GROUP"] = OpenCLExpressionUtilities::intToString(forceThreadBlockSize/OpenCLContext::TileSize);
485
    defines["CUTOFF_SQUARED"] = OpenCLExpressionUtilities::doubleToString(cutoff*cutoff);
486
487
    defines["NUM_ATOMS"] = OpenCLExpressionUtilities::intToString(context.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = OpenCLExpressionUtilities::intToString(context.getPaddedNumAtoms());
488
    defines["NUM_BLOCKS"] = OpenCLExpressionUtilities::intToString(context.getNumAtomBlocks());
489
490
491
492
493
494
495
    string file;
    if (deviceIsCpu)
        file = OpenCLKernelSources::nonbonded_cpu;
    else if (context.getSIMDWidth() == 32)
        file = OpenCLKernelSources::nonbonded_nvidia;
    else
        file = OpenCLKernelSources::nonbonded_default;
496
    cl::Program program = context.createProgram(context.replaceStrings(file, replacements), defines);
497
498
499
    cl::Kernel kernel(program, "computeNonbonded");

    // Set arguments to the Kernel.
500

501
    int index = 0;
502
503
504
505
    if (context.getSupports64BitGlobalAtomics())
        kernel.setArg<cl::Memory>(index++, context.getLongForceBuffer().getDeviceBuffer());
    else
        kernel.setArg<cl::Buffer>(index++, context.getForceBuffers().getDeviceBuffer());
506
507
508
    kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusions->getDeviceBuffer());
509
510
    kernel.setArg<cl::Buffer>(index++, exclusionIndices->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionRowIndices->getDeviceBuffer());
511
512
    kernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize*localDataSize : forceThreadBlockSize*localDataSize), NULL);
    kernel.setArg(index++, 4*forceThreadBlockSize*sizeof(cl_float), NULL);
513
514
    kernel.setArg<cl_uint>(index++, startTileIndex);
    kernel.setArg<cl_uint>(index++, startTileIndex+numTiles);
515
    if (useCutoff) {
516
517
        kernel.setArg<cl::Buffer>(index++, interactingTiles->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, interactionCount->getDeviceBuffer());
518
        index += 2; // The periodic box size arguments are set when the kernel is executed.
519
        kernel.setArg<cl_uint>(index++, interactingTiles->getSize());
Peter Eastman's avatar
Peter Eastman committed
520
        kernel.setArg<cl::Buffer>(index++, interactionFlags->getDeviceBuffer());
521
    }
522
    else {
523
        kernel.setArg<cl_uint>(index++, numTiles);
524
    }
525
    for (int i = 0; i < (int) params.size(); i++) {
526
        kernel.setArg<cl::Memory>(index++, params[i].getMemory());
527
    }
528
    for (int i = 0; i < (int) arguments.size(); i++) {
529
        kernel.setArg<cl::Memory>(index++, arguments[i].getMemory());
530
    }
531
    return kernel;
532
}