/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2011 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "OpenCLNonbondedUtilities.h"
#include "OpenCLArray.h"
29
#include "OpenCLKernelSources.h"
30
#include "OpenCLExpressionUtilities.h"
31
#include <map>
32
33
#include <set>
#include <utility>
34
35
36
37
38

using namespace OpenMM;
using namespace std;

/**
 * Construct the utilities object for a context.  All device arrays start out as NULL
 * (they are allocated later by initialize()), the cutoff starts at the sentinel value
 * -1.0 meaning "no nonbonded interaction has been registered", and the launch
 * configuration of the force kernel is selected from the device type.
 *
 * @param context    the context these utilities belong to
 */
OpenCLNonbondedUtilities::OpenCLNonbondedUtilities(OpenCLContext& context) :
        context(context),
        cutoff(-1.0),
        useCutoff(false),
        numForceBuffers(0),
        exclusionIndices(NULL),
        exclusionRowIndices(NULL),
        exclusions(NULL),
        interactingTiles(NULL),
        interactionFlags(NULL),
        interactionCount(NULL),
        blockCenter(NULL),
        blockBoundingBox(NULL),
        forceBufferFlags(NULL) {
    // Pick the number of thread blocks, the block size, and the number of force buffers.

    forceBufferPerAtomBlock = false;
    deviceIsCpu = (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
    if (deviceIsCpu) {
        // CPU: thread blocks of a single work item, one force buffer per block.

        forceThreadBlockSize = 1;
        numForceThreadBlocks = context.getNumThreadBlocks();
        numForceBuffers = numForceThreadBlocks;
    }
    else if (context.getSIMDWidth() == 32) {
        // SIMD width of 32: two blocks of 256 threads per compute unit, one force buffer per block.

        forceThreadBlockSize = 256;
        numForceThreadBlocks = 2*context.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
        numForceBuffers = numForceThreadBlocks;
    }
    else {
        // Any other device: use the context's defaults, with one force buffer per group of TileSize threads.

        forceThreadBlockSize = OpenCLContext::ThreadBlockSize;
        numForceThreadBlocks = context.getNumThreadBlocks();
        numForceBuffers = numForceThreadBlocks*forceThreadBlockSize/OpenCLContext::TileSize;

        // For small systems, it is more efficient to have one force buffer per block of atoms
        // instead of one per warp.

        forceBufferPerAtomBlock = (numForceBuffers >= context.getNumAtomBlocks());
        if (forceBufferPerAtomBlock)
            numForceBuffers = context.getNumAtomBlocks();
    }
}

/**
 * Release every device array owned by this object.  Arrays that were never allocated
 * are still NULL, and deleting a null pointer is a no-op, so no checks are needed.
 */
OpenCLNonbondedUtilities::~OpenCLNonbondedUtilities() {
    // Exclusion data.
    delete exclusionIndices;
    delete exclusionRowIndices;
    delete exclusions;

    // Neighbor list data.
    delete interactingTiles;
    delete interactionFlags;
    delete interactionCount;
    delete blockCenter;
    delete blockBoundingBox;

    // Force buffer reservation flags.
    delete forceBufferFlags;
}

89
void OpenCLNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, bool usesExclusions, double cutoffDistance, const vector<vector<int> >& exclusionList, const string& kernel) {
90
91
92
93
94
95
96
    if (cutoff != -1.0) {
        if (usesCutoff != useCutoff)
            throw OpenMMException("All Forces must agree on whether to use a cutoff");
        if (usesPeriodic != usePeriodic)
            throw OpenMMException("All Forces must agree on whether to use periodic boundary conditions");
        if (cutoffDistance != cutoff)
            throw OpenMMException("All Forces must use the same cutoff distance");
97
98
    }
    if (usesExclusions && atomExclusions.size() != 0) {
99
        bool sameExclusions = (exclusionList.size() == atomExclusions.size());
100
        for (int i = 0; i < (int) exclusionList.size() && sameExclusions; i++) {
101
102
            if (exclusionList[i].size() != atomExclusions[i].size())
                sameExclusions = false;
103
            for (int j = 0; j < (int) exclusionList[i].size(); j++)
104
105
106
107
108
109
                if (exclusionList[i][j] != atomExclusions[i][j])
                    sameExclusions = false;
        }
        if (!sameExclusions)
            throw OpenMMException("All Forces must have identical exceptions");
    }
110
111
112
113
114
115
    useCutoff = usesCutoff;
    usePeriodic = usesPeriodic;
    cutoff = cutoffDistance;
    kernelSource += kernel+"\n";
    if (usesExclusions)
        atomExclusions = exclusionList;
116
117
}

118
119
/**
 * Append a per-atom parameter to the list that initialize() passes to
 * createInteractionKernel() as the kernel's parameters.
 *
 * @param parameter    the description of the parameter to add
 */
void OpenCLNonbondedUtilities::addParameter(const ParameterInfo& parameter) {
    parameters.insert(parameters.end(), parameter);
}

122
123
124
125
/**
 * Append an extra kernel argument to the list that initialize() passes to
 * createInteractionKernel() as the kernel's additional arguments.
 *
 * @param parameter    the description of the argument to add
 */
void OpenCLNonbondedUtilities::addArgument(const ParameterInfo& parameter) {
    arguments.insert(arguments.end(), parameter);
}

126
127
128
129
void OpenCLNonbondedUtilities::initialize(const System& system) {
    if (cutoff == -1.0)
        return; // There are no nonbonded interactions in the System.
    
130
131
132
133
    if (atomExclusions.size() == 0) {
        // No exclusions were specifically requested, so just mark every atom as not interacting with itself.
        
        atomExclusions.resize(context.getNumAtoms());
134
        for (int i = 0; i < (int) atomExclusions.size(); i++)
135
136
137
            atomExclusions[i].push_back(i);
    }

138
139
140
    // Create the list of tiles.

    int numAtomBlocks = context.getNumAtomBlocks();
141
142
143
144
145
    int totalTiles = numAtomBlocks*(numAtomBlocks+1)/2;
    int numContexts = context.getPlatformData().contexts.size();
    startTileIndex = context.getContextIndex()*totalTiles/numContexts;
    int endTileIndex = (context.getContextIndex()+1)*totalTiles/numContexts;
    numTiles = endTileIndex-startTileIndex;
146
147
148

    // Build a list of indices for the tiles with exclusions.

149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
    set<pair<int, int> > tilesWithExclusions;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            tilesWithExclusions.insert(make_pair(max(x, y), min(x, y)));
        }
    }
    if (context.getPaddedNumAtoms() > context.getNumAtoms()) {
        for (int i = 0; i < numAtomBlocks; ++i)
            tilesWithExclusions.insert(make_pair(numAtomBlocks-1, i));
    }
    vector<cl_uint> exclusionRowIndicesVec(numAtomBlocks+1, 0);
    vector<cl_uint> exclusionIndicesVec;
    int currentRow = 0;
    for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter) {
        while (iter->first != currentRow)
            exclusionRowIndicesVec[++currentRow] = exclusionIndicesVec.size();
        exclusionIndicesVec.push_back(iter->second);
    }
    exclusionRowIndicesVec[++currentRow] = exclusionIndicesVec.size();
    exclusionIndices = new OpenCLArray<cl_uint>(context, exclusionIndicesVec.size(), "exclusionIndices");
    exclusionRowIndices = new OpenCLArray<cl_uint>(context, exclusionRowIndicesVec.size(), "exclusionRowIndices");
    exclusionIndices->upload(exclusionIndicesVec);
    exclusionRowIndices->upload(exclusionRowIndicesVec);
175
176
177

    // Record the exclusion data.

178
    exclusions = new OpenCLArray<cl_uint>(context, tilesWithExclusions.size()*OpenCLContext::TileSize, "exclusions");
179
180
181
182
183
184
185
186
187
188
189
    vector<cl_uint> exclusionVec(exclusions->getSize());
    for (int i = 0; i < exclusions->getSize(); ++i)
        exclusionVec[i] = 0xFFFFFFFF;
    for (int atom1 = 0; atom1 < (int) atomExclusions.size(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int j = 0; j < (int) atomExclusions[atom1].size(); ++j) {
            int atom2 = atomExclusions[atom1][j];
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x > y) {
190
191
                int index = findExclusionIndex(x, y, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset1] &= 0xFFFFFFFF-(1<<offset2);
192
193
            }
            else {
194
195
                int index = findExclusionIndex(y, x, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset2] &= 0xFFFFFFFF-(1<<offset1);
196
197
198
199
200
201
202
203
204
205
206
207
208
            }
        }
    }

    // Mark all interactions that involve a padding atom as being excluded.

    for (int atom1 = context.getNumAtoms(); atom1 < context.getPaddedNumAtoms(); ++atom1) {
        int x = atom1/OpenCLContext::TileSize;
        int offset1 = atom1-x*OpenCLContext::TileSize;
        for (int atom2 = 0; atom2 < context.getPaddedNumAtoms(); ++atom2) {
            int y = atom2/OpenCLContext::TileSize;
            int offset2 = atom2-y*OpenCLContext::TileSize;
            if (x >= y) {
209
210
                int index = findExclusionIndex(x, y, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset1] &= 0xFFFFFFFF-(1<<offset2);
211
212
            }
            if (y >= x) {
213
214
                int index = findExclusionIndex(y, x, exclusionIndicesVec, exclusionRowIndicesVec);
                exclusionVec[index+offset2] &= 0xFFFFFFFF-(1<<offset1);
215
216
217
218
219
            }
        }
    }
    atomExclusions.clear(); // We won't use this again, so free the memory it used
    exclusions->upload(exclusionVec);
220
221
222
223

    // Create data structures for the neighbor list.

    if (useCutoff) {
224
225
226
227
228
229
230
        // Select a size for the arrays that hold the neighbor list.  This estimate is intentionally very
        // high, because if it ever is too small, we have to fall back to the N^2 algorithm.

        mm_float4 boxSize = context.getPeriodicBoxSize();
        int maxInteractingTiles = (int) (numTiles*(cutoff/boxSize.x+cutoff/boxSize.y+cutoff/boxSize.z));
        if (maxInteractingTiles > numTiles)
            maxInteractingTiles = numTiles;
231
232
        if (maxInteractingTiles < 1)
            maxInteractingTiles = 1;
233
        interactingTiles = new OpenCLArray<mm_ushort2>(context, maxInteractingTiles, "interactingTiles");
234
        interactionFlags = new OpenCLArray<cl_uint>(context, context.getSIMDWidth() == 32 ? maxInteractingTiles : (deviceIsCpu ? 2*maxInteractingTiles : 1), "interactionFlags");
235
        interactionCount = new OpenCLArray<cl_uint>(context, 1, "interactionCount", true);
236
237
        blockCenter = new OpenCLArray<mm_float4>(context, numAtomBlocks, "blockCenter");
        blockBoundingBox = new OpenCLArray<mm_float4>(context, numAtomBlocks, "blockBoundingBox");
238
239
        interactionCount->set(0, 0);
        interactionCount->upload();
240
    }
241

242
243
244
245
246
247
    // Create the flags for reserving force buffers.
    
    forceBufferFlags = new OpenCLArray<cl_uint>(context, numAtomBlocks*numForceThreadBlocks, "forceBufferFlags", false);
    vector<cl_uint> forceBufferFlagsVec(forceBufferFlags->getSize(), 0);
    forceBufferFlags->upload(forceBufferFlagsVec);

248
249
    // Create kernels.

250
    forceKernel = createInteractionKernel(kernelSource, parameters, arguments, true, true);
251
252
    if (useCutoff) {
        map<string, string> defines;
253
        defines["NUM_BLOCKS"] = OpenCLExpressionUtilities::intToString(context.getNumAtomBlocks());
254
255
256
257
        if (forceBufferPerAtomBlock)
            defines["USE_OUTPUT_BUFFER_PER_BLOCK"] = "1";
        if (usePeriodic)
            defines["USE_PERIODIC"] = "1";
258
259
        string file = (deviceIsCpu ? OpenCLKernelSources::findInteractingBlocks_cpu : OpenCLKernelSources::findInteractingBlocks);
        cl::Program interactingBlocksProgram = context.createProgram(file, defines);
260
261
        findBlockBoundsKernel = cl::Kernel(interactingBlocksProgram, "findBlockBounds");
        findBlockBoundsKernel.setArg<cl_int>(0, context.getNumAtoms());
262
263
264
        findBlockBoundsKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
        findBlockBoundsKernel.setArg<cl::Buffer>(4, blockCenter->getDeviceBuffer());
        findBlockBoundsKernel.setArg<cl::Buffer>(5, blockBoundingBox->getDeviceBuffer());
265
        findBlockBoundsKernel.setArg<cl::Buffer>(6, interactionCount->getDeviceBuffer());
266
        findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions");
267
268
269
270
271
        findInteractingBlocksKernel.setArg<cl_float>(0, (cl_float) (cutoff*cutoff));
        findInteractingBlocksKernel.setArg<cl::Buffer>(3, blockCenter->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(4, blockBoundingBox->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(5, interactionCount->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
272
273
274
        findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl::Buffer>(8, context.getPosq().getDeviceBuffer());
        findInteractingBlocksKernel.setArg<cl_uint>(9, interactingTiles->getSize());
275
276
        findInteractingBlocksKernel.setArg<cl_uint>(10, startTileIndex);
        findInteractingBlocksKernel.setArg<cl_uint>(11, startTileIndex+numTiles);
277
        if (context.getSIMDWidth() == 32 && !deviceIsCpu) {
278
279
280
281
282
283
284
285
286
            findInteractionsWithinBlocksKernel = cl::Kernel(interactingBlocksProgram, "findInteractionsWithinBlocks");
            findInteractionsWithinBlocksKernel.setArg<cl_float>(0, (cl_float) (cutoff*cutoff));
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(4, interactingTiles->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(5, blockCenter->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(6, blockBoundingBox->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(8, interactionCount->getDeviceBuffer());
            findInteractionsWithinBlocksKernel.setArg(9, OpenCLContext::ThreadBlockSize*sizeof(cl_uint), NULL);
287
            findInteractionsWithinBlocksKernel.setArg<cl_uint>(10, interactingTiles->getSize());
288
        }
289
    }
290
291
}

292
293
294
295
296
297
298
299
300
/**
 * Find where the exclusion data for tile (x, y) begins.
 *
 * @param x                      the row (block index) of the tile
 * @param y                      the column (block index) of the tile
 * @param exclusionIndices       the column of every tile with exclusions, grouped by row
 * @param exclusionRowIndices    for each row, the position in exclusionIndices where it starts;
 *                               element x+1 marks where row x ends
 * @return the position of the tile within exclusionIndices, multiplied by the tile size
 * @throws OpenMMException if row x does not contain a tile with column y
 */
int OpenCLNonbondedUtilities::findExclusionIndex(int x, int y, const vector<cl_uint>& exclusionIndices, const vector<cl_uint>& exclusionRowIndices) {
    const int rowBegin = exclusionRowIndices[x];
    const int rowEnd = exclusionRowIndices[x+1];
    int position = rowBegin;
    while (position < rowEnd) {
        if (exclusionIndices[position] == y)
            return position*OpenCLContext::TileSize;
        ++position;
    }
    throw OpenMMException("Internal error: exclusion in unexpected tile");
}

301
302
303
void OpenCLNonbondedUtilities::prepareInteractions() {
    if (!useCutoff)
        return;
304
305
306

    // Compute the neighbor list.

307
308
    findBlockBoundsKernel.setArg<mm_float4>(1, context.getPeriodicBoxSize());
    findBlockBoundsKernel.setArg<mm_float4>(2, context.getInvPeriodicBoxSize());
309
    context.executeKernel(findBlockBoundsKernel, context.getNumAtoms());
310
311
    findInteractingBlocksKernel.setArg<mm_float4>(1, context.getPeriodicBoxSize());
    findInteractingBlocksKernel.setArg<mm_float4>(2, context.getInvPeriodicBoxSize());
312
313
    context.executeKernel(findInteractingBlocksKernel, context.getNumAtoms(), deviceIsCpu ? 1 : -1);
    if (context.getSIMDWidth() == 32 && !deviceIsCpu) {
314
315
        findInteractionsWithinBlocksKernel.setArg<mm_float4>(1, context.getPeriodicBoxSize());
        findInteractionsWithinBlocksKernel.setArg<mm_float4>(2, context.getInvPeriodicBoxSize());
Peter Eastman's avatar
Peter Eastman committed
316
        context.executeKernel(findInteractionsWithinBlocksKernel, context.getNumAtoms());
317
    }
318
319
320
}

void OpenCLNonbondedUtilities::computeInteractions() {
321
    if (cutoff != -1.0) {
322
        if (useCutoff) {
323
324
            forceKernel.setArg<mm_float4>(13, context.getPeriodicBoxSize());
            forceKernel.setArg<mm_float4>(14, context.getInvPeriodicBoxSize());
325
        }
326
        context.executeKernel(forceKernel, numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
327
    }
328
329
}

330
331
332
333
void OpenCLNonbondedUtilities::updateNeighborListSize() {
    if (!useCutoff)
        return;
    interactionCount->download();
334
    if (interactionCount->get(0) <= (unsigned int) interactingTiles->getSize())
335
336
337
338
339
340
341
342
343
344
345
        return;

    // The most recent timestep had too many interactions to fit in the arrays.  Make the arrays bigger to prevent
    // this from happening in the future.

    int newSize = (int) (1.2*interactionCount->get(0));
    int numTiles = context.getNumAtomBlocks()*(context.getNumAtomBlocks()+1)/2;
    if (newSize > numTiles)
        newSize = numTiles;
    delete interactingTiles;
    interactingTiles = new OpenCLArray<mm_ushort2>(context, newSize, "interactingTiles");
346
347
    forceKernel.setArg<cl::Buffer>(11, interactingTiles->getDeviceBuffer());
    forceKernel.setArg<cl_uint>(15, newSize);
348
    findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingTiles->getDeviceBuffer());
349
    findInteractingBlocksKernel.setArg<cl_uint>(9, newSize);
Peter Eastman's avatar
Peter Eastman committed
350
    if (context.getSIMDWidth() == 32 || deviceIsCpu) {
351
        delete interactionFlags;
Peter Eastman's avatar
Peter Eastman committed
352
        interactionFlags = new OpenCLArray<cl_uint>(context, deviceIsCpu ? 2*newSize : newSize, "interactionFlags");
353
        forceKernel.setArg<cl::Buffer>(16, interactionFlags->getDeviceBuffer());
354
        findInteractingBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
355
356
357
358
359
360
        findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(4, interactingTiles->getDeviceBuffer());
        findInteractionsWithinBlocksKernel.setArg<cl::Buffer>(7, interactionFlags->getDeviceBuffer());
        findInteractionsWithinBlocksKernel.setArg<cl_uint>(10, newSize);
    }
}

361
362
363
364
365
366
367
368
369
370
371
372
373
/**
 * Set the range of tiles processed by this context, and pass the new range to the
 * kernels that take it as arguments.
 *
 * @param startTileIndex    the index of the first tile to process
 * @param numTiles          the number of tiles to process
 */
void OpenCLNonbondedUtilities::setTileRange(int startTileIndex, int numTiles) {
    this->startTileIndex = startTileIndex;
    this->numTiles = numTiles;
    if (cutoff == -1.0)
        return; // There are no nonbonded interactions in the System.

    const int endTileIndex = startTileIndex+numTiles;
    forceKernel.setArg<cl_uint>(8, startTileIndex);
    forceKernel.setArg<cl_uint>(9, endTileIndex);
    if (!useCutoff)
        return;
    findInteractingBlocksKernel.setArg<cl_uint>(10, startTileIndex);
    findInteractingBlocksKernel.setArg<cl_uint>(11, endTileIndex);
}

374
/**
 * Generate, compile and configure a kernel that evaluates a nonbonded interaction.
 *
 * The kernel template is selected by device type, the text fragments generated here are
 * substituted into it, and it is compiled with defines describing the cutoff, periodicity,
 * exclusions, symmetry and system size.  All kernel arguments are then set except the two
 * periodic box arguments, which are left to be set when the kernel is executed.
 *
 * @param source           the code that computes the interaction (replaces COMPUTE_INTERACTION)
 * @param params           per-atom parameters; each becomes a "global_<name>" kernel argument
 *                         and one or more members of the per-atom local data
 * @param arguments        additional arguments, passed as __global, __constant or image2d_t
 * @param useExclusions    whether to define USE_EXCLUSIONS
 * @param isSymmetric      whether to define USE_SYMMETRIC
 * @return the kernel, ready to execute apart from the periodic box arguments
 */
cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& source, const vector<ParameterInfo>& params, const vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric) const {
    // Multi-component parameters are stored one component per member, named <name>_x, <name>_y, ...
    // NOTE(review): only four suffixes exist, so this assumes at most four components - confirm.

    const string suffixes[] = {"x", "y", "z", "w"};
    int localDataSize = 7*sizeof(cl_float);
    stringstream localData;   // member declarations of the per-atom local data
    stringstream args;        // extra kernel arguments
    stringstream loadLocal1;  // copy atom 1's parameters into local data
    stringstream loadLocal2;  // copy parameters from global memory into local data
    stringstream load1;       // load atom 1's parameters from global memory
    stringstream load2j;      // load atom 2's parameters from local data

    // Generate every per-parameter code fragment in a single pass.

    for (int i = 0; i < (int) params.size(); i++) {
        const ParameterInfo& param = params[i];
        const string name = param.getName();
        const string type = param.getType();
        const int numComponents = param.getNumComponents();
        const bool isScalar = (numComponents == 1);

        if (isScalar)
            localData<<type<<" "<<name<<";\n";
        else {
            for (int j = 0; j < numComponents; ++j)
                localData<<param.getComponentType()<<" "<<name<<"_"<<suffixes[j]<<";\n";
        }
        localDataSize += param.getSize();

        args<<", __global "<<type<<"* global_"<<name;

        if (isScalar)
            loadLocal1<<"localData[localAtomIndex]."<<name<<" = "<<name<<"1;\n";
        else {
            for (int j = 0; j < numComponents; ++j)
                loadLocal1<<"localData[localAtomIndex]."<<name<<"_"<<suffixes[j]<<" = "<<name<<"1."<<suffixes[j]<<";\n";
        }

        if (isScalar)
            loadLocal2<<"localData[localAtomIndex]."<<name<<" = global_"<<name<<"[j];\n";
        else {
            loadLocal2<<type<<" temp_"<<name<<" = global_"<<name<<"[j];\n";
            for (int j = 0; j < numComponents; ++j)
                loadLocal2<<"localData[localAtomIndex]."<<name<<"_"<<suffixes[j]<<" = temp_"<<name<<"."<<suffixes[j]<<";\n";
        }

        load1<<type<<" "<<name<<"1 = global_"<<name<<"[atom1];\n";

        if (isScalar)
            load2j<<type<<" "<<name<<"2 = localData[atom2]."<<name<<";\n";
        else {
            load2j<<type<<" "<<name<<"2 = ("<<type<<") (";
            for (int j = 0; j < numComponents; ++j) {
                if (j > 0)
                    load2j<<", ";
                load2j<<"localData[atom2]."<<name<<"_"<<suffixes[j];
            }
            load2j<<");\n";
        }
    }

    // Add a padding float whenever the local data holds an even number of 4 byte words.

    if ((localDataSize/4)%2 == 0) {
        localData<<"float padding;\n";
        localDataSize += 4;
    }

    // Declare the additional arguments after the parameter arguments.

    for (int i = 0; i < (int) arguments.size(); i++) {
        const ParameterInfo& argument = arguments[i];
        if (argument.getMemory().getInfo<CL_MEM_TYPE>() == CL_MEM_OBJECT_IMAGE2D) {
            args<<", __read_only image2d_t "<<argument.getName();
            continue;
        }
        const bool isReadOnly = ((argument.getMemory().getInfo<CL_MEM_FLAGS>() & CL_MEM_READ_ONLY) != 0);
        args<<(isReadOnly ? ", __constant " : ", __global ");
        args<<argument.getType()<<"* "<<argument.getName();
    }

    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = source;
    replacements["ATOM_PARAMETER_DATA"] = localData.str();
    replacements["PARAMETER_ARGUMENTS"] = args.str();
    replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
    replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
    replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
    replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str();

    // Collect the preprocessor definitions.

    map<string, string> defines;
    if (forceBufferPerAtomBlock)
        defines["USE_OUTPUT_BUFFER_PER_BLOCK"] = "1";
    if (useCutoff)
        defines["USE_CUTOFF"] = "1";
    if (usePeriodic)
        defines["USE_PERIODIC"] = "1";
    if (useExclusions)
        defines["USE_EXCLUSIONS"] = "1";
    if (isSymmetric)
        defines["USE_SYMMETRIC"] = "1";
    if (context.getSIMDWidth() == 32)
        defines["WARPS_PER_GROUP"] = OpenCLExpressionUtilities::intToString(forceThreadBlockSize/OpenCLContext::TileSize);
    defines["CUTOFF_SQUARED"] = OpenCLExpressionUtilities::doubleToString(cutoff*cutoff);
    defines["NUM_ATOMS"] = OpenCLExpressionUtilities::intToString(context.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = OpenCLExpressionUtilities::intToString(context.getPaddedNumAtoms());
    defines["NUM_BLOCKS"] = OpenCLExpressionUtilities::intToString(context.getNumAtomBlocks());

    // Select the kernel template for this device and compile it.

    string file;
    if (deviceIsCpu)
        file = OpenCLKernelSources::nonbonded_cpu;
    else if (context.getSIMDWidth() == 32)
        file = OpenCLKernelSources::nonbonded_nvidia;
    else
        file = OpenCLKernelSources::nonbonded_default;
    cl::Program program = context.createProgram(context.replaceStrings(file, replacements), defines);
    cl::Kernel kernel(program, "computeNonbonded");

    // Set arguments to the Kernel.

    int index = 0;
    kernel.setArg<cl::Buffer>(index++, context.getForceBuffers().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusions->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionIndices->getDeviceBuffer());
    kernel.setArg<cl::Buffer>(index++, exclusionRowIndices->getDeviceBuffer());

    // Two local memory allocations (a NULL value makes setArg reserve local memory of the given size).

    kernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize*localDataSize : forceThreadBlockSize*localDataSize), NULL);
    kernel.setArg(index++, 4*forceThreadBlockSize*sizeof(cl_float), NULL);
    kernel.setArg<cl_uint>(index++, startTileIndex);
    kernel.setArg<cl_uint>(index++, startTileIndex+numTiles);
    kernel.setArg<cl::Buffer>(index++, forceBufferFlags->getDeviceBuffer());
    if (useCutoff) {
        kernel.setArg<cl::Buffer>(index++, interactingTiles->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(index++, interactionCount->getDeviceBuffer());
        index += 2; // The periodic box size arguments are set when the kernel is executed.
        kernel.setArg<cl_uint>(index++, interactingTiles->getSize());
        kernel.setArg<cl::Buffer>(index++, interactionFlags->getDeviceBuffer());
    }
    else
        kernel.setArg<cl_uint>(index++, numTiles);
    for (int i = 0; i < (int) params.size(); i++)
        kernel.setArg<cl::Memory>(index++, params[i].getMemory());
    for (int i = 0; i < (int) arguments.size(); i++)
        kernel.setArg<cl::Memory>(index++, arguments[i].getMemory());
    return kernel;
}