/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2023 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

27
28
29
30
#ifdef WIN32
  #define _USE_MATH_DEFINES // Needed to get M_PI
#endif
#include <cmath>
#include "OpenCLContext.h"
#include "OpenCLArray.h"
#include "OpenCLBondedUtilities.h"
#include "OpenCLEvent.h"
#include "OpenCLForceInfo.h"
#include "OpenCLIntegrationUtilities.h"
#include "OpenCLKernelSources.h"
#include "OpenCLNonbondedUtilities.h"
#include "OpenCLProgram.h"
#include "openmm/common/ComputeArray.h"
#include "openmm/Platform.h"
#include "openmm/System.h"
#include "openmm/VirtualSite.h"
#include "openmm/internal/ContextImpl.h"
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iostream>
#include <set>
#include <sstream>
#include <typeinfo>
51
52

using namespace OpenMM;
53
using namespace std;
54

55
56
57
// Fallback definition of the device-info token for the major compute capability, for
// OpenCL headers that do not declare it.  It is passed to clGetDeviceInfo() below, only
// for devices that advertise the cl_nv_device_attribute_query extension.
#ifndef CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV
  #define CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV 0x4000
#endif
58
59
60
// Fallback definition of the companion token for the minor compute capability, provided
// only when the OpenCL headers have not already defined it.
#ifndef CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV
  #define CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV 0x4001
#endif
61

62
63
64
// Work-group size: exported to kernel source as the WORK_GROUP_SIZE define and used when
// sizing the energy buffer.
const int OpenCLContext::ThreadBlockSize = 64;
// The atom count is padded up to a multiple of TileSize, and the number of atom blocks is
// the padded count divided by TileSize.
const int OpenCLContext::TileSize = 32;

65
/**
 * Notification callback handed to cl::Context when this file creates a new OpenCL context.
 * It prints the message supplied by the OpenCL runtime to stderr, except for compiler
 * build-log warnings, which are silently dropped.
 *
 * @param errinfo       the message supplied by the OpenCL runtime (a null pointer is tolerated)
 * @param private_info  unused
 * @param cb            unused
 * @param user_data     unused
 */
static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_info, size_t cb, void* user_data) {
    // Neither strncmp() nor operator<< may be given a null pointer, so handle that case explicitly.
    if (errinfo == NULL) {
        std::cerr << "OpenCL internal error: (no message provided)" << std::endl;
        return;
    }

    // A static array avoids building a std::string on every invocation.
    static const char skip[] = "OpenCL Build Warning : Compiler build log:";
    if (strncmp(errinfo, skip, sizeof(skip)-1) == 0)
        return; // OS X Lion insists on calling this for every build warning, even though they aren't errors.
    std::cerr << "OpenCL internal error: " << errinfo << std::endl;
}

72
73
74
75
76
77
78
79
/**
 * Report whether an OpenCL platform comes from one of the vendors this code treats as supported.
 *
 * @param platform  the platform to examine
 * @return true if the platform's vendor string begins with one of the recognized vendor names
 */
static bool isSupported(cl::Platform platform) {
    const string vendorName = platform.getInfo<CL_PLATFORM_VENDOR>();
    const char* const recognizedPrefixes[] = {"NVIDIA", "Advanced Micro Devices", "Apple", "Intel"};
    for (const char* prefix : recognizedPrefixes) {
        // find() returning 0 means the vendor string starts with this prefix.
        if (vendorName.find(prefix) == 0)
            return true;
    }
    return false;
}

80
OpenCLContext::OpenCLContext(const System& system, int platformIndex, int deviceIndex, const string& precision, OpenCLPlatform::PlatformData& platformData, OpenCLContext* originalContext) :
81
        ComputeContext(system), platformData(platformData), numForceBuffers(0), hasAssignedPosqCharges(false),
82
        integration(NULL), expression(NULL), bonded(NULL), nonbonded(NULL), pinnedBuffer(NULL) {
83
84
85
86
87
88
89
90
91
92
93
94
95
    if (precision == "single") {
        useDoublePrecision = false;
        useMixedPrecision = false;
    }
    else if (precision == "mixed") {
        useDoublePrecision = false;
        useMixedPrecision = true;
    }
    else if (precision == "double") {
        useDoublePrecision = true;
        useMixedPrecision = false;
    }
    else
96
        throw OpenMMException("Illegal value for Precision: "+precision);
97
    try {
98
        contextIndex = platformData.contexts.size();
99
100
        std::vector<cl::Platform> platforms;
        cl::Platform::get(&platforms);
101
102
        if (platformIndex < -1 || platformIndex >= (int) platforms.size())
            throw OpenMMException("Illegal value for OpenCLPlatformIndex: "+intToString(platformIndex));
103
104
        if (platforms.size() > 1 && platformIndex == -1 && deviceIndex != -1)
            throw OpenMMException("Specified DeviceIndex but not OpenCLPlatformIndex.  When multiple platforms are available, a platform index is needed to specify a device.");
Robert McGibbon's avatar
Robert McGibbon committed
105
        const int minThreadBlockSize = 32;
106

Robert McGibbon's avatar
Robert McGibbon committed
107
108
109
        int bestSpeed = -1;
        int bestDevice = -1;
        int bestPlatform = -1;
110
        bool bestSupported = false;
Robert McGibbon's avatar
Robert McGibbon committed
111
        for (int j = 0; j < platforms.size(); j++) {
112
113
            // If they supplied a valid platformIndex, we only look through that platform
            if (j != platformIndex && platformIndex != -1)
Robert McGibbon's avatar
Robert McGibbon committed
114
115
                continue;

116
117
118
119
            // Always prefer a supported platform over an unsupported one.
            bool supported = isSupported(platforms[j]);
            if (!supported && bestSupported)
                continue;
Robert McGibbon's avatar
Robert McGibbon committed
120
121
            string platformVendor = platforms[j].getInfo<CL_PLATFORM_VENDOR>();
            vector<cl::Device> devices;
122
123
124
125
126
127
128
            try {
                platforms[j].getDevices(CL_DEVICE_TYPE_ALL, &devices);
            }
            catch (...) {
                // There are no devices available for this platform.
                continue;
            }
129
            if (deviceIndex < -1 || deviceIndex >= (int) devices.size())
130
                throw OpenMMException("Illegal value for DeviceIndex: "+intToString(deviceIndex));
Robert McGibbon's avatar
Robert McGibbon committed
131

132
            for (int i = 0; i < (int) devices.size(); i++) {
133
134
                // If they supplied a valid deviceIndex, we only look through that one
                if (i != deviceIndex && deviceIndex != -1)
Robert McGibbon's avatar
Robert McGibbon committed
135
                    continue;
136
137
                if (platformVendor == "Apple" && (devices[i].getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU))
                    continue; // The CPU device on OS X won't work correctly.
138
139
140
141
142
                if (useMixedPrecision || useDoublePrecision) {
                    bool supportsDouble = (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_fp64") != string::npos);
                    if (!supportsDouble)
                        continue; // This device does not support double precision.
                }
143
                int maxSize = devices[i].getInfo<CL_DEVICE_MAX_WORK_ITEM_SIZES>()[0];
144
145
146
147
148
                int processingElementsPerComputeUnit = 8;
                if (devices[i].getInfo<CL_DEVICE_TYPE>() != CL_DEVICE_TYPE_GPU) {
                    processingElementsPerComputeUnit = 1;
                }
                else if (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_nv_device_attribute_query") != string::npos) {
149
150
151
152
                    cl_uint computeCapabilityMajor;
                    clGetDeviceInfo(devices[i](), CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV, sizeof(cl_uint), &computeCapabilityMajor, NULL);
                    processingElementsPerComputeUnit = (computeCapabilityMajor < 2 ? 8 : 32);
                }
153
154
155
156
157
158
159
160
161
162
                else if (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_amd_device_attribute_query") != string::npos) {
                    // This attribute does not ensure that all queries are supported by the runtime (it may be an older runtime,
                    // or the CPU device) so still have to check for errors.
                    try {
                        processingElementsPerComputeUnit =
                            // AMD GPUs either have a single VLIW SIMD or multiple scalar SIMDs.
                            // The SIMD width is the number of threads the SIMD executes per cycle.
                            // This will be less than the wavefront width since it takes several
                            // cycles to execute the full wavefront.
                            // The SIMD instruction width is the VLIW instruction width (or 1 for scalar),
163
                            // this is the number of ALUs that can be executing per instruction per thread.
164
165
166
167
168
169
170
171
172
173
174
                            devices[i].getInfo<CL_DEVICE_SIMD_PER_COMPUTE_UNIT_AMD>() *
                            devices[i].getInfo<CL_DEVICE_SIMD_WIDTH_AMD>() *
                            devices[i].getInfo<CL_DEVICE_SIMD_INSTRUCTION_WIDTH_AMD>();
                        // Just in case any of the queries return 0.
                        if (processingElementsPerComputeUnit <= 0)
                            processingElementsPerComputeUnit = 1;
                    }
                    catch (cl::Error err) {
                        // Runtime does not support the queries so use default.
                    }
                }
175
                int speed = devices[i].getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>()*processingElementsPerComputeUnit*devices[i].getInfo<CL_DEVICE_MAX_CLOCK_FREQUENCY>();
176
                if (maxSize >= minThreadBlockSize && (speed > bestSpeed || (supported && !bestSupported))) {
Robert McGibbon's avatar
Robert McGibbon committed
177
                    bestDevice = i;
178
                    bestSpeed = speed;
Robert McGibbon's avatar
Robert McGibbon committed
179
                    bestPlatform = j;
180
                    bestSupported = supported;
181
                }
182
            }
183
        }
Robert McGibbon's avatar
Robert McGibbon committed
184
185
186
187
188

        if (bestPlatform == -1)
            throw OpenMMException("No compatible OpenCL platform is available");

        if (bestDevice == -1)
189
            throw OpenMMException("No compatible OpenCL device is available");
190
191
192
        
        if (!bestSupported)
            cout << "WARNING: Using an unsupported OpenCL implementation.  Results may be incorrect." << endl;
Robert McGibbon's avatar
Robert McGibbon committed
193
194
195

        vector<cl::Device> devices;
        platforms[bestPlatform].getDevices(CL_DEVICE_TYPE_ALL, &devices);
Robert McGibbon's avatar
Robert McGibbon committed
196
        string platformVendor = platforms[bestPlatform].getInfo<CL_PLATFORM_VENDOR>();
Robert McGibbon's avatar
Robert McGibbon committed
197
198
199
        device = devices[bestDevice];

        this->deviceIndex = bestDevice;
Robert McGibbon's avatar
Robert McGibbon committed
200
        this->platformIndex = bestPlatform;
201
        if (device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>() < minThreadBlockSize)
202
            throw OpenMMException("The specified OpenCL device is not compatible with OpenMM");
203
        compilationDefines["WORK_GROUP_SIZE"] = intToString(ThreadBlockSize);
Peter Eastman's avatar
Peter Eastman committed
204
        if (platformVendor.size() >= 5 && platformVendor.substr(0, 5) == "Intel")
205
206
            defaultOptimizationOptions = "";
        else
207
            defaultOptimizationOptions = "-cl-mad-enable -cl-no-signed-zeros";
208
        supports64BitGlobalAtomics = (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_int64_base_atomics") != string::npos);
209
        supportsDoublePrecision = (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_fp64") != string::npos);
210
211
        if ((useDoublePrecision || useMixedPrecision) && !supportsDoublePrecision)
            throw OpenMMException("This device does not support double precision");
212
        string vendor = device.getInfo<CL_DEVICE_VENDOR>();
213
        int numThreadBlocksPerComputeUnit = 6;
214
215
216
217
218
      
        if (vendor.size() >= 5 && vendor.substr(0, 5) == "Apple") {
            simdWidth = 32;
        }
        else if (vendor.size() >= 6 && vendor.substr(0, 6) == "NVIDIA") {
219
            compilationDefines["WARPS_ARE_ATOMIC"] = "";
220
            simdWidth = 32;
221
222
            if (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_nv_device_attribute_query") != string::npos) {
                // Compute level 1.2 and later Nvidia GPUs support 64 bit atomics, even though they don't list the
223
224
                // proper extension as supported.  We only use them on compute level 2.0 or later, since they're very
                // slow on earlier GPUs.
225

226
                cl_uint computeCapabilityMajor;
227
                clGetDeviceInfo(device(), CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV, sizeof(cl_uint), &computeCapabilityMajor, NULL);
228
                if (computeCapabilityMajor > 1)
229
                    supports64BitGlobalAtomics = true;
230
231
232
233
234
235
236
                if (computeCapabilityMajor == 5) {
                    // Workaround for a bug in Maxwell on CUDA 6.x.

                    string platformVersion = platforms[bestPlatform].getInfo<CL_PLATFORM_VERSION>();
                    if (platformVersion.find("CUDA 6") != string::npos)
                        supports64BitGlobalAtomics = false;
                }
237
            }
238
        }
239
        else if (vendor.size() >= 28 && vendor.substr(0, 28) == "Advanced Micro Devices, Inc.") {
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
            if (device.getInfo<CL_DEVICE_TYPE>() != CL_DEVICE_TYPE_GPU) {
                /// \todo Is 6 a good value for the OpenCL CPU device?
                // numThreadBlocksPerComputeUnit = ?;
                simdWidth = 1;
            }
            else {
                bool amdPostSdk2_4 = false;
                // Default to 1 which will use the default kernels.
                simdWidth = 1;
                if (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_amd_device_attribute_query") != string::npos) {
                    // This attribute does not ensure that all queries are supported by the runtime so still have to
                    // check for errors.
                    try {
                        // Must catch cl:Error as will fail if runtime does not support queries.

                        cl_uint simdPerComputeUnit = device.getInfo<CL_DEVICE_SIMD_PER_COMPUTE_UNIT_AMD>();
256
257
                        simdWidth = device.getInfo<CL_DEVICE_WAVEFRONT_WIDTH_AMD>();

258
259
260
                        // If the GPU has multiple SIMDs per compute unit then it is uses the scalar instruction
                        // set instead of the VLIW instruction set. It therefore needs more thread blocks per
                        // compute unit to hide memory latency.
Peter Eastman's avatar
Peter Eastman committed
261
262
263
264
265
266
                        if (simdPerComputeUnit > 1) {
                            if (simdWidth == 32)
                                numThreadBlocksPerComputeUnit = 6*simdPerComputeUnit; // Navi seems to like more thread blocks than older GPUs
                            else
                                numThreadBlocksPerComputeUnit = 4*simdPerComputeUnit;
                        }
267
268
269
270
271
272
273
274
275
276
277
278
279

                        // If the queries are supported then must be newer than SDK 2.4.
                        amdPostSdk2_4 = true;
                    }
                    catch (cl::Error err) {
                        // Runtime does not support the query so is unlikely to be the newer scalar GPU.
                        // Stay with the default simdWidth and numThreadBlocksPerComputeUnit.
                    }
                }
                // AMD APP SDK 2.4 has a performance problem with atomics. Enable the work around. This is fixed after SDK 2.4.
                if (!amdPostSdk2_4)
                    compilationDefines["AMD_ATOMIC_WORK_AROUND"] = "";
            }
280
        }
281
282
        else
            simdWidth = 1;
283
        if (supports64BitGlobalAtomics)
284
            compilationDefines["SUPPORTS_64_BIT_ATOMICS"] = "";
285
286
        if (supportsDoublePrecision)
            compilationDefines["SUPPORTS_DOUBLE_PRECISION"] = "";
287
        if (simdWidth >= 32)
Peter Eastman's avatar
Peter Eastman committed
288
            compilationDefines["SYNC_WARPS"] = "mem_fence(CLK_LOCAL_MEM_FENCE)";
289
290
        else
            compilationDefines["SYNC_WARPS"] = "barrier(CLK_LOCAL_MEM_FENCE)";
291
292
        vector<cl::Device> contextDevices;
        contextDevices.push_back(device);
Robert McGibbon's avatar
Robert McGibbon committed
293
        cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[bestPlatform](), 0};
294
295
296
297
298
299
300
301
        if (originalContext == NULL) {
            context = cl::Context(contextDevices, cprops, errorCallback);
            defaultQueue = cl::CommandQueue(context, device);
        }
        else {
            context = originalContext->context;
            defaultQueue = originalContext->defaultQueue;
        }
302
        currentQueue = defaultQueue;
Peter Eastman's avatar
Peter Eastman committed
303
304
        numAtoms = system.getNumParticles();
        paddedNumAtoms = TileSize*((numAtoms+TileSize-1)/TileSize);
305
        numAtomBlocks = (paddedNumAtoms+(TileSize-1))/TileSize;
306
        numThreadBlocks = numThreadBlocksPerComputeUnit*device.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
307
        if (useDoublePrecision) {
peastman's avatar
peastman committed
308
309
            posq.initialize<mm_double4>(*this, paddedNumAtoms, "posq");
            velm.initialize<mm_double4>(*this, paddedNumAtoms, "velm");
310
311
            compilationDefines["USE_DOUBLE_PRECISION"] = "1";
            compilationDefines["convert_real4"] = "convert_double4";
312
313
314
            compilationDefines["make_real2"] = "make_double2";
            compilationDefines["make_real3"] = "make_double3";
            compilationDefines["make_real4"] = "make_double4";
315
            compilationDefines["convert_mixed4"] = "convert_double4";
316
317
318
            compilationDefines["make_mixed2"] = "make_double2";
            compilationDefines["make_mixed3"] = "make_double3";
            compilationDefines["make_mixed4"] = "make_double4";
319
320
        }
        else if (useMixedPrecision) {
peastman's avatar
peastman committed
321
322
323
            posq.initialize<mm_float4>(*this, paddedNumAtoms, "posq");
            posqCorrection.initialize<mm_float4>(*this, paddedNumAtoms, "posq");
            velm.initialize<mm_double4>(*this, paddedNumAtoms, "velm");
324
325
            compilationDefines["USE_MIXED_PRECISION"] = "1";
            compilationDefines["convert_real4"] = "convert_float4";
326
327
328
            compilationDefines["make_real2"] = "make_float2";
            compilationDefines["make_real3"] = "make_float3";
            compilationDefines["make_real4"] = "make_float4";
329
            compilationDefines["convert_mixed4"] = "convert_double4";
330
331
332
            compilationDefines["make_mixed2"] = "make_double2";
            compilationDefines["make_mixed3"] = "make_double3";
            compilationDefines["make_mixed4"] = "make_double4";
333
334
        }
        else {
peastman's avatar
peastman committed
335
336
            posq.initialize<mm_float4>(*this, paddedNumAtoms, "posq");
            velm.initialize<mm_float4>(*this, paddedNumAtoms, "velm");
337
            compilationDefines["convert_real4"] = "convert_float4";
338
339
340
            compilationDefines["make_real2"] = "make_float2";
            compilationDefines["make_real3"] = "make_float3";
            compilationDefines["make_real4"] = "make_float4";
341
            compilationDefines["convert_mixed4"] = "convert_float4";
342
343
344
            compilationDefines["make_mixed2"] = "make_float2";
            compilationDefines["make_mixed3"] = "make_float3";
            compilationDefines["make_mixed4"] = "make_float4";
345
        }
346
        longForceBuffer.initialize<cl_long>(*this, 3*paddedNumAtoms, "longForceBuffer");
347
        posCellOffsets.resize(paddedNumAtoms, mm_int4(0, 0, 0, 0));
348
349
350
351
352
        atomIndexDevice.initialize<cl_int>(*this, paddedNumAtoms, "atomIndexDevice");
        atomIndex.resize(paddedNumAtoms);
        for (int i = 0; i < paddedNumAtoms; ++i)
            atomIndex[i] = i;
        atomIndexDevice.upload(atomIndex);
353
354
355
356
357
    }
    catch (cl::Error err) {
        std::stringstream str;
        str<<"Error initializing context: "<<err.what()<<" ("<<err.err()<<")";
        throw OpenMMException(str.str());
358
    }
359
360
361

    // Create utility kernels that are used in multiple places.

Peter Eastman's avatar
Peter Eastman committed
362
    cl::Program utilities = createProgram(OpenCLKernelSources::utilities);
363
    clearBufferKernel = cl::Kernel(utilities, "clearBuffer");
364
365
366
    clearTwoBuffersKernel = cl::Kernel(utilities, "clearTwoBuffers");
    clearThreeBuffersKernel = cl::Kernel(utilities, "clearThreeBuffers");
    clearFourBuffersKernel = cl::Kernel(utilities, "clearFourBuffers");
367
368
    clearFiveBuffersKernel = cl::Kernel(utilities, "clearFiveBuffers");
    clearSixBuffersKernel = cl::Kernel(utilities, "clearSixBuffers");
369
    reduceReal4Kernel = cl::Kernel(utilities, "reduceReal4Buffer");
370
    reduceForcesKernel = cl::Kernel(utilities, "reduceForces");
Peter Eastman's avatar
Peter Eastman committed
371
    reduceEnergyKernel = cl::Kernel(utilities, "reduceEnergy");
372
    setChargesKernel = cl::Kernel(utilities, "setCharges");
373
374
375

    // Decide whether native_sqrt(), native_rsqrt(), and native_recip() are sufficiently accurate to use.

376
377
378
379
380
    if (!useDoublePrecision) {
        cl::Kernel accuracyKernel(utilities, "determineNativeAccuracy");
        OpenCLArray valuesArray(*this, 20, sizeof(mm_float8), "values");
        vector<mm_float8> values(valuesArray.getSize());
        float nextValue = 1e-4f;
peastman's avatar
peastman committed
381
382
        for (auto& val : values) {
            val.s0 = nextValue;
383
384
385
386
387
388
389
390
            nextValue *= (float) M_PI;
        }
        valuesArray.upload(values);
        accuracyKernel.setArg<cl::Buffer>(0, valuesArray.getDeviceBuffer());
        accuracyKernel.setArg<cl_int>(1, values.size());
        executeKernel(accuracyKernel, values.size());
        valuesArray.download(values);
        double maxSqrtError = 0.0, maxRsqrtError = 0.0, maxRecipError = 0.0, maxExpError = 0.0, maxLogError = 0.0;
peastman's avatar
peastman committed
391
392
        for (auto& val : values) {
            double v = val.s0;
393
            double correctSqrt = sqrt(v);
peastman's avatar
peastman committed
394
395
396
397
398
            maxSqrtError = max(maxSqrtError, fabs(correctSqrt-val.s1)/correctSqrt);
            maxRsqrtError = max(maxRsqrtError, fabs(1.0/correctSqrt-val.s2)*correctSqrt);
            maxRecipError = max(maxRecipError, fabs(1.0/v-val.s3)/val.s3);
            maxExpError = max(maxExpError, fabs(exp(v)-val.s4)/val.s4);
            maxLogError = max(maxLogError, fabs(log(v)-val.s5)/val.s5);
399
400
401
402
403
404
405
406
407
408
409
410
411
412
        }
        compilationDefines["SQRT"] = (maxSqrtError < 1e-6) ? "native_sqrt" : "sqrt";
        compilationDefines["RSQRT"] = (maxRsqrtError < 1e-6) ? "native_rsqrt" : "rsqrt";
        compilationDefines["RECIP"] = (maxRecipError < 1e-6) ? "native_recip" : "1.0f/";
        compilationDefines["EXP"] = (maxExpError < 1e-6) ? "native_exp" : "exp";
        compilationDefines["LOG"] = (maxLogError < 1e-6) ? "native_log" : "log";
    }
    else {
        compilationDefines["SQRT"] = "sqrt";
        compilationDefines["RSQRT"] = "rsqrt";
        compilationDefines["RECIP"] = "1.0/";
        compilationDefines["EXP"] = "exp";
        compilationDefines["LOG"] = "log";
    }
413
414
415
416
417
418
419
    compilationDefines["POW"] = "pow";
    compilationDefines["COS"] = "cos";
    compilationDefines["SIN"] = "sin";
    compilationDefines["TAN"] = "tan";
    compilationDefines["ACOS"] = "acos";
    compilationDefines["ASIN"] = "asin";
    compilationDefines["ATAN"] = "atan";
420
421
    compilationDefines["ERF"] = "erf";
    compilationDefines["ERFC"] = "erfc";
422

423
    // Set defines for applying periodic boundary conditions.
424

425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
    Vec3 boxVectors[3];
    system.getDefaultPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
    boxIsTriclinic = (boxVectors[0][1] != 0.0 || boxVectors[0][2] != 0.0 ||
                      boxVectors[1][0] != 0.0 || boxVectors[1][2] != 0.0 ||
                      boxVectors[2][0] != 0.0 || boxVectors[2][1] != 0.0);
    if (boxIsTriclinic) {
        compilationDefines["APPLY_PERIODIC_TO_DELTA(delta)"] =
            "{"
            "real scale3 = floor(delta.z*invPeriodicBoxSize.z+0.5f); \\\n"
            "delta.xyz -= scale3*periodicBoxVecZ.xyz; \\\n"
            "real scale2 = floor(delta.y*invPeriodicBoxSize.y+0.5f); \\\n"
            "delta.xy -= scale2*periodicBoxVecY.xy; \\\n"
            "real scale1 = floor(delta.x*invPeriodicBoxSize.x+0.5f); \\\n"
            "delta.x -= scale1*periodicBoxVecX.x;}";
        compilationDefines["APPLY_PERIODIC_TO_POS(pos)"] =
            "{"
            "real scale3 = floor(pos.z*invPeriodicBoxSize.z); \\\n"
            "pos.xyz -= scale3*periodicBoxVecZ.xyz; \\\n"
            "real scale2 = floor(pos.y*invPeriodicBoxSize.y); \\\n"
            "pos.xy -= scale2*periodicBoxVecY.xy; \\\n"
            "real scale1 = floor(pos.x*invPeriodicBoxSize.x); \\\n"
            "pos.x -= scale1*periodicBoxVecX.x;}";
        compilationDefines["APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center)"] =
            "{"
            "real scale3 = floor((pos.z-center.z)*invPeriodicBoxSize.z+0.5f); \\\n"
            "pos.x -= scale3*periodicBoxVecZ.x; \\\n"
            "pos.y -= scale3*periodicBoxVecZ.y; \\\n"
            "pos.z -= scale3*periodicBoxVecZ.z; \\\n"
            "real scale2 = floor((pos.y-center.y)*invPeriodicBoxSize.y+0.5f); \\\n"
            "pos.x -= scale2*periodicBoxVecY.x; \\\n"
            "pos.y -= scale2*periodicBoxVecY.y; \\\n"
            "real scale1 = floor((pos.x-center.x)*invPeriodicBoxSize.x+0.5f); \\\n"
            "pos.x -= scale1*periodicBoxVecX.x;}";
    }
    else {
        compilationDefines["APPLY_PERIODIC_TO_DELTA(delta)"] =
            "delta.xyz -= floor(delta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;";
        compilationDefines["APPLY_PERIODIC_TO_POS(pos)"] =
            "pos.xyz -= floor(pos.xyz*invPeriodicBoxSize.xyz)*periodicBoxSize.xyz;";
        compilationDefines["APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center)"] =
            "{"
            "pos.x -= floor((pos.x-center.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x; \\\n"
            "pos.y -= floor((pos.y-center.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y; \\\n"
            "pos.z -= floor((pos.z-center.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;}";
    }

471
    // Create utilities objects.
472

473
474
    bonded = new OpenCLBondedUtilities(*this);
    nonbonded = new OpenCLNonbondedUtilities(*this);
Peter Eastman's avatar
Peter Eastman committed
475
    integration = new OpenCLIntegrationUtilities(*this, system);
476
    expression = new OpenCLExpressionUtilities(*this);
477
478
479
}

/**
 * Release every heap object this context deletes on destruction: the registered force infos,
 * reorder listeners, pre- and post-computations, the pinned buffer, and the utility objects.
 */
OpenCLContext::~OpenCLContext() {
    for (auto forceInfo : forces)
        delete forceInfo;
    for (auto reorderListener : reorderListeners)
        delete reorderListener;
    for (auto preComputation : preComputations)
        delete preComputation;
    for (auto postComputation : postComputations)
        delete postComputation;

    // Deleting a null pointer is a no-op, so no explicit NULL checks are needed.
    delete pinnedBuffer;
    delete integration;
    delete expression;
    delete bonded;
    delete nonbonded;
}

500
void OpenCLContext::initialize() {
Peter Eastman's avatar
Peter Eastman committed
501
    bonded->initialize(system);
502
    numForceBuffers = std::max(numForceBuffers, (int) platformData.contexts.size());
503
    int energyBufferSize = max(numThreadBlocks*ThreadBlockSize, nonbonded->getNumEnergyBuffers());
504
    int numComputeUnits = device.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
505
    if (useDoublePrecision) {
peastman's avatar
peastman committed
506
507
508
        forceBuffers.initialize<mm_double4>(*this, paddedNumAtoms*numForceBuffers, "forceBuffers");
        force.initialize<mm_double4>(*this, &forceBuffers.getDeviceBuffer(), paddedNumAtoms, "force");
        energyBuffer.initialize<cl_double>(*this, energyBufferSize, "energyBuffer");
509
        energySum.initialize<cl_double>(*this, numComputeUnits, "energySum");
510
    }
Peter Eastman's avatar
Peter Eastman committed
511
    else if (useMixedPrecision) {
peastman's avatar
peastman committed
512
513
514
        forceBuffers.initialize<mm_float4>(*this, paddedNumAtoms*numForceBuffers, "forceBuffers");
        force.initialize<mm_float4>(*this, &forceBuffers.getDeviceBuffer(), paddedNumAtoms, "force");
        energyBuffer.initialize<cl_double>(*this, energyBufferSize, "energyBuffer");
515
        energySum.initialize<cl_double>(*this, numComputeUnits, "energySum");
Peter Eastman's avatar
Peter Eastman committed
516
517
    }
    else {
peastman's avatar
peastman committed
518
519
520
        forceBuffers.initialize<mm_float4>(*this, paddedNumAtoms*numForceBuffers, "forceBuffers");
        force.initialize<mm_float4>(*this, &forceBuffers.getDeviceBuffer(), paddedNumAtoms, "force");
        energyBuffer.initialize<cl_float>(*this, energyBufferSize, "energyBuffer");
521
        energySum.initialize<cl_float>(*this, numComputeUnits, "energySum");
522
    }
523
524
525
526
    reduceForcesKernel.setArg<cl::Buffer>(0, longForceBuffer.getDeviceBuffer());
    reduceForcesKernel.setArg<cl::Buffer>(1, forceBuffers.getDeviceBuffer());
    reduceForcesKernel.setArg<cl_int>(2, paddedNumAtoms);
    reduceForcesKernel.setArg<cl_int>(3, numForceBuffers);
527
    addAutoclearBuffer(longForceBuffer);
peastman's avatar
peastman committed
528
529
    addAutoclearBuffer(forceBuffers);
    addAutoclearBuffer(energyBuffer);
530
531
532
    int numEnergyParamDerivs = energyParamDerivNames.size();
    if (numEnergyParamDerivs > 0) {
        if (useDoublePrecision || useMixedPrecision)
peastman's avatar
peastman committed
533
            energyParamDerivBuffer.initialize<cl_double>(*this, numEnergyParamDerivs*energyBufferSize, "energyParamDerivBuffer");
534
        else
peastman's avatar
peastman committed
535
536
            energyParamDerivBuffer.initialize<cl_float>(*this, numEnergyParamDerivs*energyBufferSize, "energyParamDerivBuffer");
        addAutoclearBuffer(energyParamDerivBuffer);
537
    }
538
    int bufferBytes = max(max((int) velm.getSize()*velm.getElementSize(),
539
            energyBufferSize*energyBuffer.getElementSize()),
540
            (int) longForceBuffer.getSize()*longForceBuffer.getElementSize());
541
    pinnedBuffer = new cl::Buffer(context, CL_MEM_ALLOC_HOST_PTR, bufferBytes);
542
    pinnedMemory = currentQueue.enqueueMapBuffer(*pinnedBuffer, CL_TRUE, CL_MAP_READ | CL_MAP_WRITE, 0, bufferBytes);
543
544
545
546
547
548
549
    for (int i = 0; i < numAtoms; i++) {
        double mass = system.getParticleMass(i);
        if (useDoublePrecision || useMixedPrecision)
            ((mm_double4*) pinnedMemory)[i] = mm_double4(0.0, 0.0, 0.0, mass == 0.0 ? 0.0 : 1.0/mass);
        else
            ((mm_float4*) pinnedMemory)[i] = mm_float4(0.0f, 0.0f, 0.0f, mass == 0.0 ? 0.0f : (cl_float) (1.0/mass));
    }
peastman's avatar
peastman committed
550
    velm.upload(pinnedMemory);
551
    findMoleculeGroups();
552
    nonbonded->initialize(system);
553
554
}

555
556
void OpenCLContext::initializeContexts() {
    getPlatformData().initializeContexts(system);
557
558
}

559
560
561
562
563
/**
 * Register a force with this context.
 *
 * The force is first passed to ComputeContext::addForce().  If it is an
 * OpenCLForceInfo, the number of force buffers it requires is also requested.
 *
 * @param force  the force info object to add
 */
void OpenCLContext::addForce(ComputeForceInfo* force) {
    ComputeContext::addForce(force);
    // Only an OpenCLForceInfo can report a force buffer requirement.
    if (OpenCLForceInfo* openclInfo = dynamic_cast<OpenCLForceInfo*>(force))
        requestForceBuffers(openclInfo->getRequiredForceBuffers());
}

566
567
/**
 * Request that at least a given number of force buffers be available.
 * The stored count only ever grows: a request smaller than the current
 * value has no effect.
 *
 * @param minBuffers  the minimum number of force buffers required
 */
void OpenCLContext::requestForceBuffers(int minBuffers) {
    if (minBuffers > numForceBuffers)
        numForceBuffers = minBuffers;
}

570
571
/**
 * Build a program from source without any caller-supplied preprocessor
 * defines.  This simply forwards to the three-argument overload with an
 * empty set of defines.
 *
 * @param source             the kernel source code to compile
 * @param optimizationFlags  build options forwarded unchanged to the other overload
 * @return the built program
 */
cl::Program OpenCLContext::createProgram(const string source, const char* optimizationFlags) {
    const map<string, string> noDefines;
    return createProgram(source, noDefines, optimizationFlags);
}

574
/**
 * Assemble a complete OpenCL program around a kernel source string and build
 * it for this context's device.
 *
 * The generated source consists of, in order: a comment recording the
 * compilation options (if any), every entry of compilationDefines that is not
 * overridden by `defines`, the cl_khr_fp64 pragma (when
 * supportsDoublePrecision is set), typedefs for the `real` and `mixed` type
 * families, the common kernel source, the entries of `defines`, and finally
 * `source` itself.
 *
 * @param source             the kernel source code to compile
 * @param defines            preprocessor macros to define; an entry with an
 *                           empty value is defined with no replacement text
 * @param optimizationFlags  build options passed to the compiler; if NULL,
 *                           defaultOptimizationOptions is used instead
 * @return the built program
 * @throws OpenMMException containing the build log if building fails
 */
cl::Program OpenCLContext::createProgram(const string source, const map<string, string>& defines, const char* optimizationFlags) {
    string options = (optimizationFlags == NULL ? defaultOptimizationOptions : string(optimizationFlags));
    stringstream src;

    // Emits one "#define NAME [VALUE]" line.
    auto writeDefine = [&src](const string& name, const string& value) {
        src << "#define " << name;
        if (!value.empty())
            src << " " << value;
        src << endl;
    };

    // Emits typedefs mapping alias, alias2, alias3 and alias4 onto the scalar
    // type baseType and its 2, 3 and 4 component vector forms.
    auto writeTypedefs = [&src](const char* baseType, const char* alias) {
        src << "typedef " << baseType << " " << alias << ";\n";
        for (int width = 2; width <= 4; width++)
            src << "typedef " << baseType << width << " " << alias << width << ";\n";
    };

    if (!options.empty())
        src << "// Compilation Options: " << options << endl << endl;
    for (auto& pair : compilationDefines) {
        // Skip anything the caller defines itself, to avoid duplicate definitions.
        if (defines.find(pair.first) == defines.end())
            writeDefine(pair.first, pair.second);
    }
    if (!compilationDefines.empty())
        src << endl;
    if (supportsDoublePrecision)
        src << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n";
    writeTypedefs(useDoublePrecision ? "double" : "float", "real");
    writeTypedefs((useDoublePrecision || useMixedPrecision) ? "double" : "float", "mixed");
    src << OpenCLKernelSources::common << endl;
    for (auto& pair : defines)
        writeDefine(pair.first, pair.second);
    if (!defines.empty())
        src << endl;
    src << source << endl;

    cl::Program::Sources sources({src.str()});
    cl::Program program(context, sources);
    try {
        program.build(vector<cl::Device>(1, device), options.c_str());
    }
    catch (const cl::Error&) {
        // Catch by const reference: no copy of the exception and no slicing.
        throw OpenMMException("Error compiling kernel: "+program.getBuildInfo<CL_PROGRAM_BUILD_LOG>(device));
    }
    return program;
}

636
637
638
639
640
641
642
643
644
645
646
647
/**
 * Get the command queue this context is currently using.
 *
 * @return a reference to the `currentQueue` member
 */
cl::CommandQueue& OpenCLContext::getQueue() {
    return currentQueue;
}

/**
 * Make the given command queue the current one.  The queue is copy-assigned
 * into the `currentQueue` member.
 *
 * @param queue  the queue to use from now on
 */
void OpenCLContext::setQueue(cl::CommandQueue& queue) {
    currentQueue = queue;
}

/**
 * Switch back to the default command queue by assigning `defaultQueue` to
 * `currentQueue`.
 */
void OpenCLContext::restoreDefaultQueue() {
    currentQueue = defaultQueue;
}

648
649
650
651
652
653
/**
 * Allocate a new default-constructed OpenCLArray.
 *
 * @return a raw pointer to the array, obtained from `new`
 *         NOTE(review): the caller presumably takes ownership and must
 *         delete it -- confirm against the base class contract.
 */
OpenCLArray* OpenCLContext::createArray() {
    return new OpenCLArray();
}

/**
 * Create a new OpenCLEvent bound to this context.
 *
 * @return the event, held in a shared_ptr to its ComputeEventImpl base
 */
ComputeEvent OpenCLContext::createEvent() {
    shared_ptr<ComputeEventImpl> event(new OpenCLEvent(*this));
    return event;
}

656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
/**
 * Compile source code into a program usable through the platform-independent
 * ComputeProgram interface.
 *
 * The source is built with createProgram() and the result wrapped in an
 * OpenCLProgram.
 *
 * @param source   the kernel source code to compile
 * @param defines  preprocessor macros to define while compiling
 * @return the compiled program, held in a shared_ptr to its ComputeProgramImpl base
 */
ComputeProgram OpenCLContext::compileProgram(const std::string source, const std::map<std::string, std::string>& defines) {
    cl::Program compiled = createProgram(source, defines);
    shared_ptr<ComputeProgramImpl> wrapper(new OpenCLProgram(*this, compiled));
    return wrapper;
}

/**
 * Obtain the OpenCLArray behind a generic array reference.
 *
 * If `array` is a ComputeArray, the array it wraps is examined; otherwise
 * `array` itself is.
 *
 * @param array  the array to convert
 * @return the underlying OpenCLArray
 * @throws OpenMMException if the examined object is not an OpenCLArray
 */
OpenCLArray& OpenCLContext::unwrap(ArrayInterface& array) const {
    ComputeArray* wrapper = dynamic_cast<ComputeArray*>(&array);
    OpenCLArray* clarray = (wrapper == NULL
            ? dynamic_cast<OpenCLArray*>(&array)
            : dynamic_cast<OpenCLArray*>(&wrapper->getArray()));
    if (clarray != NULL)
        return *clarray;
    throw OpenMMException("Array argument is not an OpenCLArray");
}

673
674
675
676
/**
 * Enqueue a kernel for execution on the current command queue.
 *
 * The global work size is `workUnits` rounded up to a whole number of blocks,
 * capped at numThreadBlocks blocks.
 *
 * @param kernel     the kernel to execute
 * @param workUnits  the number of work units the kernel should cover
 * @param blockSize  the work-group size, or -1 to use ThreadBlockSize
 * @throws OpenMMException naming the kernel and the OpenCL error if the
 *         kernel cannot be enqueued
 */
void OpenCLContext::executeKernel(cl::Kernel& kernel, int workUnits, int blockSize) {
    if (blockSize == -1)
        blockSize = ThreadBlockSize;
    const int numBlocks = std::min((workUnits+blockSize-1)/blockSize, numThreadBlocks);
    const int size = numBlocks*blockSize;
    try {
        currentQueue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange(size), cl::NDRange(blockSize));
    }
    catch (const cl::Error& err) {
        // Catch by const reference: no copy of the exception and no slicing.
        stringstream str;
        str<<"Error invoking kernel "<<kernel.getInfo<CL_KERNEL_FUNCTION_NAME>()<<": "<<err.what()<<" ("<<err.err()<<")";
        throw OpenMMException(str.str());
    }
}

687
688
689
690
691
692
693
694
695
696
697
698
699
/**
 * Pick a thread block size based on the device's local memory size.
 *
 * @param memory  the divisor applied to the local memory budget
 *                (presumably the local memory needed per thread, in bytes --
 *                TODO confirm against callers)
 * @return 32 if fewer than 64 threads fit in the budget; otherwise a multiple
 *         of 64, grown in steps of 64 for as long as the next step stays
 *         strictly below the computed limit
 */
int OpenCLContext::computeThreadBlockSize(double memory) const {
    const int maxShared = device.getInfo<CL_DEVICE_LOCAL_MEM_SIZE>();
    // On some implementations, more local memory gets used than we calculate by
    // adding up the sizes of the fields.  To be safe, only budget for half of it.
    const int limit = (int) (0.5*maxShared/memory);
    if (limit < 64)
        return 32;
    int threads;
    for (threads = 64; threads+64 < limit; threads += 64)
        ;
    return threads;
}

700
701
/**
 * Clear the full contents of an array.  The array is unwrapped to its
 * OpenCLArray and its device buffer is passed to the cl::Memory overload
 * together with the array's total size in bytes.
 *
 * @param array  the array to clear
 */
void OpenCLContext::clearBuffer(ArrayInterface& array) {
    OpenCLArray& clarray = unwrap(array);
    clearBuffer(clarray.getDeviceBuffer(), array.getSize()*array.getElementSize());
}

704
/**
 * Run the clearBuffer kernel on a block of device memory.
 *
 * @param memory  the memory object to clear
 * @param size    the size of the region in bytes; it is divided by 4 (integer
 *                division) to obtain the element count given to the kernel
 */
void OpenCLContext::clearBuffer(cl::Memory& memory, int size) {
    const int numWords = size/4;
    clearBufferKernel.setArg<cl::Memory>(0, memory);
    clearBufferKernel.setArg<cl_int>(1, numWords);
    const int workGroupSize = 128;
    executeKernel(clearBufferKernel, numWords, workGroupSize);
}

711
712
/**
 * Register an array to be cleared by clearAutoclearBuffers().  The array is
 * unwrapped to its OpenCLArray and its device buffer is registered together
 * with the array's total size in bytes.
 *
 * @param array  the array to register
 */
void OpenCLContext::addAutoclearBuffer(ArrayInterface& array) {
    OpenCLArray& clarray = unwrap(array);
    addAutoclearBuffer(clarray.getDeviceBuffer(), array.getSize()*array.getElementSize());
}

715
716
/**
 * Register a block of device memory to be cleared by clearAutoclearBuffers().
 *
 * Only the address of `memory` is stored, so the object must remain valid for
 * as long as it is registered.
 *
 * @param memory  the memory object to register
 * @param size    the size of the region in bytes; it is stored divided by 4
 *                (integer division)
 */
void OpenCLContext::addAutoclearBuffer(cl::Memory& memory, int size) {
    const int numWords = size/4;
    autoclearBuffers.push_back(&memory);
    autoclearBufferSizes.push_back(numWords);
}

void OpenCLContext::clearAutoclearBuffers() {
    int base = 0;
    int total = autoclearBufferSizes.size();
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
    while (total-base >= 6) {
        clearSixBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearSixBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearSixBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearSixBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearSixBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearSixBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearSixBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearSixBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
        clearSixBuffersKernel.setArg<cl::Memory>(8, *autoclearBuffers[base+4]);
        clearSixBuffersKernel.setArg<cl_int>(9, autoclearBufferSizes[base+4]);
        clearSixBuffersKernel.setArg<cl::Memory>(10, *autoclearBuffers[base+5]);
        clearSixBuffersKernel.setArg<cl_int>(11, autoclearBufferSizes[base+5]);
        executeKernel(clearSixBuffersKernel, max(max(max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), autoclearBufferSizes[base+4]), autoclearBufferSizes[base+5]), 128);
        base += 6;
    }
    if (total-base == 5) {
        clearFiveBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearFiveBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearFiveBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearFiveBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearFiveBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearFiveBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearFiveBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearFiveBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
        clearFiveBuffersKernel.setArg<cl::Memory>(8, *autoclearBuffers[base+4]);
        clearFiveBuffersKernel.setArg<cl_int>(9, autoclearBufferSizes[base+4]);
        executeKernel(clearFiveBuffersKernel, max(max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), autoclearBufferSizes[base+4]), 128);
    }
    else if (total-base == 4) {
753
754
755
756
757
758
759
760
        clearFourBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearFourBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearFourBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearFourBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearFourBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearFourBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearFourBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearFourBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
761
        executeKernel(clearFourBuffersKernel, max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), 128);
762
    }
763
    else if (total-base == 3) {
764
765
766
767
768
769
        clearThreeBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearThreeBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearThreeBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearThreeBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearThreeBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearThreeBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
770
        executeKernel(clearThreeBuffersKernel, max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), 128);
771
772
773
774
775
776
    }
    else if (total-base == 2) {
        clearTwoBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearTwoBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearTwoBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearTwoBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
777
        executeKernel(clearTwoBuffersKernel, max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), 128);
778
779
    }
    else if (total-base == 1) {
780
        clearBuffer(*autoclearBuffers[base], autoclearBufferSizes[base]*4);
781
782
783
    }
}

784
void OpenCLContext::reduceForces() {
785
    executeKernel(reduceForcesKernel, paddedNumAtoms, 128);
786
787
}

788
/**
 * Launch the reduceReal4 kernel on an array made up of several equally sized
 * buffers.
 *
 * @param array       the array holding `numBuffers` consecutive buffers
 * @param longBuffer  a second array passed to the kernel as its second argument
 * @param numBuffers  the number of buffers packed into `array`; the size of
 *                    each one is taken to be array.getSize()/numBuffers
 */
void OpenCLContext::reduceBuffer(OpenCLArray& array, OpenCLArray& longBuffer, int numBuffers) {
    const int elementsPerBuffer = array.getSize()/numBuffers;
    reduceReal4Kernel.setArg<cl::Buffer>(0, array.getDeviceBuffer());
    reduceReal4Kernel.setArg<cl::Buffer>(1, longBuffer.getDeviceBuffer());
    reduceReal4Kernel.setArg<cl_int>(2, elementsPerBuffer);
    reduceReal4Kernel.setArg<cl_int>(3, numBuffers);
    const int workGroupSize = 128;
    executeKernel(reduceReal4Kernel, elementsPerBuffer, workGroupSize);
}
796

Peter Eastman's avatar
Peter Eastman committed
797
798
799
800
double OpenCLContext::reduceEnergy() {
    int workGroupSize  = device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>();
    if (workGroupSize > 512)
        workGroupSize = 512;
peastman's avatar
peastman committed
801
802
803
    reduceEnergyKernel.setArg<cl::Buffer>(0, energyBuffer.getDeviceBuffer());
    reduceEnergyKernel.setArg<cl::Buffer>(1, energySum.getDeviceBuffer());
    reduceEnergyKernel.setArg<cl_int>(2, energyBuffer.getSize());
Peter Eastman's avatar
Peter Eastman committed
804
    reduceEnergyKernel.setArg<cl_int>(3, workGroupSize);
peastman's avatar
peastman committed
805
    reduceEnergyKernel.setArg(4, workGroupSize*energyBuffer.getElementSize(), NULL);
806
807
808
    executeKernel(reduceEnergyKernel, workGroupSize*energySum.getSize(), workGroupSize);
    energySum.download(pinnedMemory);
    double result = 0;
Peter Eastman's avatar
Peter Eastman committed
809
    if (getUseDoublePrecision() || getUseMixedPrecision()) {
810
811
        for (int i = 0; i < energySum.getSize(); i++)
            result += ((double*) pinnedMemory)[i];
Peter Eastman's avatar
Peter Eastman committed
812
813
    }
    else {
814
815
        for (int i = 0; i < energySum.getSize(); i++)
            result += ((float*) pinnedMemory)[i];
Peter Eastman's avatar
Peter Eastman committed
816
    }
817
    return result;
Peter Eastman's avatar
Peter Eastman committed
818
819
}

820
void OpenCLContext::setCharges(const vector<double>& charges) {
peastman's avatar
peastman committed
821
822
    if (!chargeBuffer.isInitialized())
        chargeBuffer.initialize(*this, numAtoms, useDoublePrecision ? sizeof(double) : sizeof(float), "chargeBuffer");
peastman's avatar
peastman committed
823
824
825
    vector<double> c(numAtoms);
    for (int i = 0; i < numAtoms; i++)
        c[i] = charges[i];
826
    chargeBuffer.upload(c, true);
peastman's avatar
peastman committed
827
828
829
    setChargesKernel.setArg<cl::Buffer>(0, chargeBuffer.getDeviceBuffer());
    setChargesKernel.setArg<cl::Buffer>(1, posq.getDeviceBuffer());
    setChargesKernel.setArg<cl::Buffer>(2, atomIndexDevice.getDeviceBuffer());
830
831
832
833
    setChargesKernel.setArg<cl_int>(3, numAtoms);
    executeKernel(setChargesKernel, numAtoms);
}

834
835
836
837
838
839
/**
 * Request the right to store charges in posq.  The request is granted only
 * to the first caller; every call marks the charges as assigned.
 *
 * @return true on the first call, false on every later call
 */
bool OpenCLContext::requestPosqCharges() {
    const bool alreadyAssigned = hasAssignedPosqCharges;
    hasAssignedPosqCharges = true;
    return !alreadyAssigned;
}

840
841
842
843
844
845
846
847
848
/**
 * Register a parameter for which the derivative of the energy should be
 * computed.  Registering the same name more than once has no effect.
 *
 * @param param  the name of the parameter
 */
void OpenCLContext::addEnergyParameterDerivative(const string& param) {
    // Nothing to do if this parameter has already been registered.
    for (const auto& registered : energyParamDerivNames) {
        if (registered == param)
            return;
    }
    energyParamDerivNames.push_back(param);
}

849
850
void OpenCLContext::flushQueue() {
    getQueue().flush();
851
}