/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2020 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#ifdef WIN32
  #define _USE_MATH_DEFINES // Needed to get M_PI
#endif
#include <cmath>
#include "OpenCLContext.h"
#include "OpenCLArray.h"
#include "OpenCLBondedUtilities.h"
#include "OpenCLEvent.h"
#include "OpenCLForceInfo.h"
#include "OpenCLIntegrationUtilities.h"
#include "OpenCLKernelSources.h"
#include "OpenCLNonbondedUtilities.h"
#include "OpenCLProgram.h"
#include "openmm/common/ComputeArray.h"
#include "openmm/Platform.h"
#include "openmm/System.h"
#include "openmm/VirtualSite.h"
#include "openmm/internal/ContextImpl.h"
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iostream>
#include <set>
#include <sstream>
#include <typeinfo>

using namespace OpenMM;
using namespace std;

#ifndef CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV
  #define CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV 0x4000
#endif

#ifndef CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV
  #define CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV 0x4001
#endif

// Default work group size; exported to kernels as WORK_GROUP_SIZE.
const int OpenCLContext::ThreadBlockSize = 64;
// Atoms per tile; paddedNumAtoms is rounded up to a multiple of this.
const int OpenCLContext::TileSize = 32;

/**
 * Error-notification callback registered with cl::Context.
 *
 * Messages beginning with "OpenCL Build Warning : Compiler build log:" are
 * dropped; every other message is written to std::cerr.  private_info, cb and
 * user_data are part of the required callback signature and are not used.
 */
static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_info, size_t cb, void* user_data) {
    // Compare against a static literal instead of building a std::string on every call.
    static const char skip[] = "OpenCL Build Warning : Compiler build log:";
    if (std::strncmp(errinfo, skip, sizeof(skip)-1) == 0)
        return; // OS X Lion insists on calling this for every build warning, even though they aren't errors.
    std::cerr << "OpenCL internal error: " << errinfo << std::endl;
}

/**
 * Returns true if the platform's CL_PLATFORM_VENDOR string starts with one of
 * the vendor names treated as supported: NVIDIA, Advanced Micro Devices, Apple
 * or Intel.
 */
static bool isSupported(cl::Platform platform) {
    static const char* const knownVendors[] = {"NVIDIA", "Advanced Micro Devices", "Apple", "Intel"};
    string vendor = platform.getInfo<CL_PLATFORM_VENDOR>();
    for (const char* prefix : knownVendors)
        if (vendor.rfind(prefix, 0) == 0) // Only tests position 0, i.e. a prefix match.
            return true;
    return false;
}

/**
 * Create a context for simulating a System on an OpenCL device.
 *
 * Selects a platform and device (restricted to platformIndex / deviceIndex when
 * they are not -1, otherwise the fastest estimated device, preferring supported
 * vendors).  It then derives compilation defines from the device vendor and
 * capabilities and allocates the per-atom arrays for the requested precision
 * ("single", "mixed" or "double").  It also checks whether the native_* math
 * builtins are accurate enough to use, sets the periodic-boundary macros, and
 * creates the bonded/nonbonded/integration/expression utility objects.
 *
 * If originalContext is not NULL, its cl::Context and default command queue are
 * shared instead of being created anew.
 *
 * Throws OpenMMException for an illegal precision, platform index or device index,
 * when no compatible device is found, or when an OpenCL call fails during device
 * selection and array allocation.
 */
OpenCLContext::OpenCLContext(const System& system, int platformIndex, int deviceIndex, const string& precision, OpenCLPlatform::PlatformData& platformData, OpenCLContext* originalContext) :
        ComputeContext(system), platformData(platformData), numForceBuffers(0), hasAssignedPosqCharges(false),
        integration(NULL), expression(NULL), bonded(NULL), nonbonded(NULL), pinnedBuffer(NULL) {
    if (precision == "single") {
        useDoublePrecision = false;
        useMixedPrecision = false;
    }
    else if (precision == "mixed") {
        useDoublePrecision = false;
        useMixedPrecision = true;
    }
    else if (precision == "double") {
        useDoublePrecision = true;
        useMixedPrecision = false;
    }
    else
        throw OpenMMException("Illegal value for Precision: "+precision);
    try {
        contextIndex = platformData.contexts.size();
        std::vector<cl::Platform> platforms;
        cl::Platform::get(&platforms);
        if (platformIndex < -1 || platformIndex >= (int) platforms.size())
            throw OpenMMException("Illegal value for OpenCLPlatformIndex: "+intToString(platformIndex));
        if (platforms.size() > 1 && platformIndex == -1 && deviceIndex != -1)
            throw OpenMMException("Specified DeviceIndex but not OpenCLPlatformIndex.  When multiple platforms are available, a platform index is needed to specify a device.");
        const int minThreadBlockSize = 32;

        int bestSpeed = -1;
        int bestDevice = -1;
        int bestPlatform = -1;
        bool bestSupported = false;
        for (int j = 0; j < (int) platforms.size(); j++) {
            // If they supplied a valid platformIndex, we only look through that platform
            if (j != platformIndex && platformIndex != -1)
                continue;

            // Always prefer a supported platform over an unsupported one.
            bool supported = isSupported(platforms[j]);
            if (!supported && bestSupported)
                continue;
            string platformVendor = platforms[j].getInfo<CL_PLATFORM_VENDOR>();
            vector<cl::Device> devices;
            try {
                platforms[j].getDevices(CL_DEVICE_TYPE_ALL, &devices);
            }
            catch (...) {
                // There are no devices available for this platform.
                continue;
            }
            if (deviceIndex < -1 || deviceIndex >= (int) devices.size())
                throw OpenMMException("Illegal value for DeviceIndex: "+intToString(deviceIndex));

            for (int i = 0; i < (int) devices.size(); i++) {
                // If they supplied a valid deviceIndex, we only look through that one
                if (i != deviceIndex && deviceIndex != -1)
                    continue;
                if (platformVendor == "Apple" && (devices[i].getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU))
                    continue; // The CPU device on OS X won't work correctly.
                if (useMixedPrecision || useDoublePrecision) {
                    bool supportsDouble = (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_fp64") != string::npos);
                    if (!supportsDouble)
                        continue; // This device does not support double precision.
                }
                int maxSize = devices[i].getInfo<CL_DEVICE_MAX_WORK_ITEM_SIZES>()[0];
                int processingElementsPerComputeUnit = 8;
                if (devices[i].getInfo<CL_DEVICE_TYPE>() != CL_DEVICE_TYPE_GPU) {
                    processingElementsPerComputeUnit = 1;
                }
                else if (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_nv_device_attribute_query") != string::npos) {
                    // Initialized because the return code of clGetDeviceInfo is not checked.
                    cl_uint computeCapabilityMajor = 0;
                    clGetDeviceInfo(devices[i](), CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV, sizeof(cl_uint), &computeCapabilityMajor, NULL);
                    processingElementsPerComputeUnit = (computeCapabilityMajor < 2 ? 8 : 32);
                }
                else if (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_amd_device_attribute_query") != string::npos) {
                    // This attribute does not ensure that all queries are supported by the runtime (it may be an older runtime,
                    // or the CPU device) so still have to check for errors.
                    try {
                        processingElementsPerComputeUnit =
                            // AMD GPUs either have a single VLIW SIMD or multiple scalar SIMDs.
                            // The SIMD width is the number of threads the SIMD executes per cycle.
                            // This will be less than the wavefront width since it takes several
                            // cycles to execute the full wavefront.
                            // The SIMD instruction width is the VLIW instruction width (or 1 for scalar),
                            // this is the number of ALUs that can be executing per instruction per thread.
                            devices[i].getInfo<CL_DEVICE_SIMD_PER_COMPUTE_UNIT_AMD>() *
                            devices[i].getInfo<CL_DEVICE_SIMD_WIDTH_AMD>() *
                            devices[i].getInfo<CL_DEVICE_SIMD_INSTRUCTION_WIDTH_AMD>();
                        // Just in case any of the queries return 0.
                        if (processingElementsPerComputeUnit <= 0)
                            processingElementsPerComputeUnit = 1;
                    }
                    catch (const cl::Error&) {
                        // Runtime does not support the queries so use default.
                    }
                }
                int speed = devices[i].getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>()*processingElementsPerComputeUnit*devices[i].getInfo<CL_DEVICE_MAX_CLOCK_FREQUENCY>();
                if (maxSize >= minThreadBlockSize && (speed > bestSpeed || (supported && !bestSupported))) {
                    bestDevice = i;
                    bestSpeed = speed;
                    bestPlatform = j;
                    bestSupported = supported;
                }
            }
        }

        if (bestPlatform == -1)
            throw OpenMMException("No compatible OpenCL platform is available");

        if (bestDevice == -1)
            throw OpenMMException("No compatible OpenCL device is available");

        if (!bestSupported)
            cout << "WARNING: Using an unsupported OpenCL implementation.  Results may be incorrect." << endl;

        vector<cl::Device> devices;
        platforms[bestPlatform].getDevices(CL_DEVICE_TYPE_ALL, &devices);
        string platformVendor = platforms[bestPlatform].getInfo<CL_PLATFORM_VENDOR>();
        device = devices[bestDevice];

        this->deviceIndex = bestDevice;
        this->platformIndex = bestPlatform;
        if (device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>() < minThreadBlockSize)
            throw OpenMMException("The specified OpenCL device is not compatible with OpenMM");
        compilationDefines["WORK_GROUP_SIZE"] = intToString(ThreadBlockSize);
        if (platformVendor.size() >= 5 && platformVendor.substr(0, 5) == "Intel")
            defaultOptimizationOptions = "";
        else
            defaultOptimizationOptions = "-cl-mad-enable -cl-no-signed-zeros";
        supports64BitGlobalAtomics = (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_int64_base_atomics") != string::npos);
        supportsDoublePrecision = (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_fp64") != string::npos);
        if ((useDoublePrecision || useMixedPrecision) && !supportsDoublePrecision)
            throw OpenMMException("This device does not support double precision");
        string vendor = device.getInfo<CL_DEVICE_VENDOR>();
        int numThreadBlocksPerComputeUnit = 6;
        if (vendor.size() >= 6 && vendor.substr(0, 6) == "NVIDIA") {
            compilationDefines["WARPS_ARE_ATOMIC"] = "";
            simdWidth = 32;
            if (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_nv_device_attribute_query") != string::npos) {
                // Compute level 1.2 and later Nvidia GPUs support 64 bit atomics, even though they don't list the
                // proper extension as supported.  We only use them on compute level 2.0 or later, since they're very
                // slow on earlier GPUs.

                // Initialized because the return code of clGetDeviceInfo is not checked; 0 keeps the
                // conservative default (no forced 64 bit atomics).
                cl_uint computeCapabilityMajor = 0;
                clGetDeviceInfo(device(), CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV, sizeof(cl_uint), &computeCapabilityMajor, NULL);
                if (computeCapabilityMajor > 1)
                    supports64BitGlobalAtomics = true;
                if (computeCapabilityMajor == 5) {
                    // Workaround for a bug in Maxwell on CUDA 6.x.

                    string platformVersion = platforms[bestPlatform].getInfo<CL_PLATFORM_VERSION>();
                    if (platformVersion.find("CUDA 6") != string::npos)
                        supports64BitGlobalAtomics = false;
                }
            }
        }
        else if (vendor.size() >= 28 && vendor.substr(0, 28) == "Advanced Micro Devices, Inc.") {
            if (device.getInfo<CL_DEVICE_TYPE>() != CL_DEVICE_TYPE_GPU) {
                /// \todo Is 6 a good value for the OpenCL CPU device?
                // numThreadBlocksPerComputeUnit = ?;
                simdWidth = 1;
            }
            else {
                bool amdPostSdk2_4 = false;
                // Default to 1 which will use the default kernels.
                simdWidth = 1;
                if (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_amd_device_attribute_query") != string::npos) {
                    // This attribute does not ensure that all queries are supported by the runtime so still have to
                    // check for errors.
                    try {
                        // Must catch cl:Error as will fail if runtime does not support queries.

                        cl_uint simdPerComputeUnit = device.getInfo<CL_DEVICE_SIMD_PER_COMPUTE_UNIT_AMD>();
                        simdWidth = device.getInfo<CL_DEVICE_WAVEFRONT_WIDTH_AMD>();

                        // If the GPU has multiple SIMDs per compute unit then it is uses the scalar instruction
                        // set instead of the VLIW instruction set. It therefore needs more thread blocks per
                        // compute unit to hide memory latency.
                        if (simdPerComputeUnit > 1) {
                            if (simdWidth == 32)
                                numThreadBlocksPerComputeUnit = 6*simdPerComputeUnit; // Navi seems to like more thread blocks than older GPUs
                            else
                                numThreadBlocksPerComputeUnit = 4*simdPerComputeUnit;
                        }

                        // If the queries are supported then must be newer than SDK 2.4.
                        amdPostSdk2_4 = true;
                    }
                    catch (const cl::Error&) {
                        // Runtime does not support the query so is unlikely to be the newer scalar GPU.
                        // Stay with the default simdWidth and numThreadBlocksPerComputeUnit.
                    }
                }
                // AMD APP SDK 2.4 has a performance problem with atomics. Enable the work around. This is fixed after SDK 2.4.
                if (!amdPostSdk2_4)
                    compilationDefines["AMD_ATOMIC_WORK_AROUND"] = "";
            }
        }
        else
            simdWidth = 1;
        if (supports64BitGlobalAtomics)
            compilationDefines["SUPPORTS_64_BIT_ATOMICS"] = "";
        if (supportsDoublePrecision)
            compilationDefines["SUPPORTS_DOUBLE_PRECISION"] = "";
        if (simdWidth >= 32)
            compilationDefines["SYNC_WARPS"] = "mem_fence(CLK_LOCAL_MEM_FENCE)";
        else
            compilationDefines["SYNC_WARPS"] = "barrier(CLK_LOCAL_MEM_FENCE)";
        vector<cl::Device> contextDevices;
        contextDevices.push_back(device);
        cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[bestPlatform](), 0};
        if (originalContext == NULL) {
            context = cl::Context(contextDevices, cprops, errorCallback);
            defaultQueue = cl::CommandQueue(context, device);
        }
        else {
            context = originalContext->context;
            defaultQueue = originalContext->defaultQueue;
        }
        currentQueue = defaultQueue;
        numAtoms = system.getNumParticles();
        paddedNumAtoms = TileSize*((numAtoms+TileSize-1)/TileSize);
        numAtomBlocks = (paddedNumAtoms+(TileSize-1))/TileSize;
        numThreadBlocks = numThreadBlocksPerComputeUnit*device.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
        if (useDoublePrecision) {
            posq.initialize<mm_double4>(*this, paddedNumAtoms, "posq");
            velm.initialize<mm_double4>(*this, paddedNumAtoms, "velm");
            compilationDefines["USE_DOUBLE_PRECISION"] = "1";
            compilationDefines["convert_real4"] = "convert_double4";
            compilationDefines["make_real2"] = "make_double2";
            compilationDefines["make_real3"] = "make_double3";
            compilationDefines["make_real4"] = "make_double4";
            compilationDefines["convert_mixed4"] = "convert_double4";
            compilationDefines["make_mixed2"] = "make_double2";
            compilationDefines["make_mixed3"] = "make_double3";
            compilationDefines["make_mixed4"] = "make_double4";
        }
        else if (useMixedPrecision) {
            posq.initialize<mm_float4>(*this, paddedNumAtoms, "posq");
            // Give the correction array its own name so diagnostics don't confuse it with posq.
            posqCorrection.initialize<mm_float4>(*this, paddedNumAtoms, "posqCorrection");
            velm.initialize<mm_double4>(*this, paddedNumAtoms, "velm");
            compilationDefines["USE_MIXED_PRECISION"] = "1";
            compilationDefines["convert_real4"] = "convert_float4";
            compilationDefines["make_real2"] = "make_float2";
            compilationDefines["make_real3"] = "make_float3";
            compilationDefines["make_real4"] = "make_float4";
            compilationDefines["convert_mixed4"] = "convert_double4";
            compilationDefines["make_mixed2"] = "make_double2";
            compilationDefines["make_mixed3"] = "make_double3";
            compilationDefines["make_mixed4"] = "make_double4";
        }
        else {
            posq.initialize<mm_float4>(*this, paddedNumAtoms, "posq");
            velm.initialize<mm_float4>(*this, paddedNumAtoms, "velm");
            compilationDefines["convert_real4"] = "convert_float4";
            compilationDefines["make_real2"] = "make_float2";
            compilationDefines["make_real3"] = "make_float3";
            compilationDefines["make_real4"] = "make_float4";
            compilationDefines["convert_mixed4"] = "convert_float4";
            compilationDefines["make_mixed2"] = "make_float2";
            compilationDefines["make_mixed3"] = "make_float3";
            compilationDefines["make_mixed4"] = "make_float4";
        }
        longForceBuffer.initialize<cl_long>(*this, 3*paddedNumAtoms, "longForceBuffer");
        posCellOffsets.resize(paddedNumAtoms, mm_int4(0, 0, 0, 0));
        atomIndexDevice.initialize<cl_int>(*this, paddedNumAtoms, "atomIndexDevice");
        atomIndex.resize(paddedNumAtoms);
        for (int i = 0; i < paddedNumAtoms; ++i)
            atomIndex[i] = i;
        atomIndexDevice.upload(atomIndex);
    }
    catch (const cl::Error& err) {
        std::stringstream str;
        str<<"Error initializing context: "<<err.what()<<" ("<<err.err()<<")";
        throw OpenMMException(str.str());
    }

    // Create utility kernels that are used in multiple places.

    cl::Program utilities = createProgram(OpenCLKernelSources::utilities);
    clearBufferKernel = cl::Kernel(utilities, "clearBuffer");
    clearTwoBuffersKernel = cl::Kernel(utilities, "clearTwoBuffers");
    clearThreeBuffersKernel = cl::Kernel(utilities, "clearThreeBuffers");
    clearFourBuffersKernel = cl::Kernel(utilities, "clearFourBuffers");
    clearFiveBuffersKernel = cl::Kernel(utilities, "clearFiveBuffers");
    clearSixBuffersKernel = cl::Kernel(utilities, "clearSixBuffers");
    reduceReal4Kernel = cl::Kernel(utilities, "reduceReal4Buffer");
    reduceForcesKernel = cl::Kernel(utilities, "reduceForces");
    reduceEnergyKernel = cl::Kernel(utilities, "reduceEnergy");
    setChargesKernel = cl::Kernel(utilities, "setCharges");

    // Decide whether native_sqrt(), native_rsqrt(), and native_recip() are sufficiently accurate to use.

    if (!useDoublePrecision) {
        cl::Kernel accuracyKernel(utilities, "determineNativeAccuracy");
        OpenCLArray valuesArray(*this, 20, sizeof(mm_float8), "values");
        vector<mm_float8> values(valuesArray.getSize());
        // Sample inputs from 1e-4 upward in geometric steps of pi.
        float nextValue = 1e-4f;
        for (auto& val : values) {
            val.s0 = nextValue;
            nextValue *= (float) M_PI;
        }
        valuesArray.upload(values);
        accuracyKernel.setArg<cl::Buffer>(0, valuesArray.getDeviceBuffer());
        accuracyKernel.setArg<cl_int>(1, values.size());
        executeKernel(accuracyKernel, values.size());
        valuesArray.download(values);
        double maxSqrtError = 0.0, maxRsqrtError = 0.0, maxRecipError = 0.0, maxExpError = 0.0, maxLogError = 0.0;
        for (auto& val : values) {
            double v = val.s0;
            double correctSqrt = sqrt(v);
            maxSqrtError = max(maxSqrtError, fabs(correctSqrt-val.s1)/correctSqrt);
            maxRsqrtError = max(maxRsqrtError, fabs(1.0/correctSqrt-val.s2)*correctSqrt);
            maxRecipError = max(maxRecipError, fabs(1.0/v-val.s3)/val.s3);
            maxExpError = max(maxExpError, fabs(exp(v)-val.s4)/val.s4);
            // log(v) is negative for v < 1.  Dividing by the signed value would make the
            // relative error negative, so max() would silently ignore those samples.
            maxLogError = max(maxLogError, fabs(log(v)-val.s5)/fabs(val.s5));
        }
        compilationDefines["SQRT"] = (maxSqrtError < 1e-6) ? "native_sqrt" : "sqrt";
        compilationDefines["RSQRT"] = (maxRsqrtError < 1e-6) ? "native_rsqrt" : "rsqrt";
        compilationDefines["RECIP"] = (maxRecipError < 1e-6) ? "native_recip" : "1.0f/";
        compilationDefines["EXP"] = (maxExpError < 1e-6) ? "native_exp" : "exp";
        compilationDefines["LOG"] = (maxLogError < 1e-6) ? "native_log" : "log";
    }
    else {
        compilationDefines["SQRT"] = "sqrt";
        compilationDefines["RSQRT"] = "rsqrt";
        compilationDefines["RECIP"] = "1.0/";
        compilationDefines["EXP"] = "exp";
        compilationDefines["LOG"] = "log";
    }
    compilationDefines["POW"] = "pow";
    compilationDefines["COS"] = "cos";
    compilationDefines["SIN"] = "sin";
    compilationDefines["TAN"] = "tan";
    compilationDefines["ACOS"] = "acos";
    compilationDefines["ASIN"] = "asin";
    compilationDefines["ATAN"] = "atan";
    compilationDefines["ERF"] = "erf";
    compilationDefines["ERFC"] = "erfc";

    // Set defines for applying periodic boundary conditions.

    Vec3 boxVectors[3];
    system.getDefaultPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
    boxIsTriclinic = (boxVectors[0][1] != 0.0 || boxVectors[0][2] != 0.0 ||
                      boxVectors[1][0] != 0.0 || boxVectors[1][2] != 0.0 ||
                      boxVectors[2][0] != 0.0 || boxVectors[2][1] != 0.0);
    if (boxIsTriclinic) {
        compilationDefines["APPLY_PERIODIC_TO_DELTA(delta)"] =
            "{"
            "real scale3 = floor(delta.z*invPeriodicBoxSize.z+0.5f); \\\n"
            "delta.xyz -= scale3*periodicBoxVecZ.xyz; \\\n"
            "real scale2 = floor(delta.y*invPeriodicBoxSize.y+0.5f); \\\n"
            "delta.xy -= scale2*periodicBoxVecY.xy; \\\n"
            "real scale1 = floor(delta.x*invPeriodicBoxSize.x+0.5f); \\\n"
            "delta.x -= scale1*periodicBoxVecX.x;}";
        compilationDefines["APPLY_PERIODIC_TO_POS(pos)"] =
            "{"
            "real scale3 = floor(pos.z*invPeriodicBoxSize.z); \\\n"
            "pos.xyz -= scale3*periodicBoxVecZ.xyz; \\\n"
            "real scale2 = floor(pos.y*invPeriodicBoxSize.y); \\\n"
            "pos.xy -= scale2*periodicBoxVecY.xy; \\\n"
            "real scale1 = floor(pos.x*invPeriodicBoxSize.x); \\\n"
            "pos.x -= scale1*periodicBoxVecX.x;}";
        compilationDefines["APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center)"] =
            "{"
            "real scale3 = floor((pos.z-center.z)*invPeriodicBoxSize.z+0.5f); \\\n"
            "pos.x -= scale3*periodicBoxVecZ.x; \\\n"
            "pos.y -= scale3*periodicBoxVecZ.y; \\\n"
            "pos.z -= scale3*periodicBoxVecZ.z; \\\n"
            "real scale2 = floor((pos.y-center.y)*invPeriodicBoxSize.y+0.5f); \\\n"
            "pos.x -= scale2*periodicBoxVecY.x; \\\n"
            "pos.y -= scale2*periodicBoxVecY.y; \\\n"
            "real scale1 = floor((pos.x-center.x)*invPeriodicBoxSize.x+0.5f); \\\n"
            "pos.x -= scale1*periodicBoxVecX.x;}";
    }
    else {
        compilationDefines["APPLY_PERIODIC_TO_DELTA(delta)"] =
            "delta.xyz -= floor(delta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;";
        compilationDefines["APPLY_PERIODIC_TO_POS(pos)"] =
            "pos.xyz -= floor(pos.xyz*invPeriodicBoxSize.xyz)*periodicBoxSize.xyz;";
        compilationDefines["APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center)"] =
            "{"
            "pos.x -= floor((pos.x-center.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x; \\\n"
            "pos.y -= floor((pos.y-center.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y; \\\n"
            "pos.z -= floor((pos.z-center.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;}";
    }

    // Create utilities objects.

    bonded = new OpenCLBondedUtilities(*this);
    nonbonded = new OpenCLNonbondedUtilities(*this);
    integration = new OpenCLIntegrationUtilities(*this, system);
    expression = new OpenCLExpressionUtilities(*this);
}

/**
 * Destroy the context.  Deletes the owned force infos, reorder listeners,
 * pre/post computations, the pinned host buffer, and the utility objects
 * created by the constructor.
 */
OpenCLContext::~OpenCLContext() {
    for (auto force : forces)
        delete force;
    for (auto listener : reorderListeners)
        delete listener;
    for (auto computation : preComputations)
        delete computation;
    for (auto computation : postComputations)
        delete computation;
    // Deleting a null pointer is a no-op, so these need no explicit NULL checks.
    delete pinnedBuffer;
    delete integration;
    delete expression;
    delete bonded;
    delete nonbonded;
}

void OpenCLContext::initialize() {
Peter Eastman's avatar
Peter Eastman committed
497
    bonded->initialize(system);
498
    numForceBuffers = std::max(numForceBuffers, (int) platformData.contexts.size());
499
    int energyBufferSize = max(numThreadBlocks*ThreadBlockSize, nonbonded->getNumEnergyBuffers());
500
    if (useDoublePrecision) {
peastman's avatar
peastman committed
501
502
503
504
        forceBuffers.initialize<mm_double4>(*this, paddedNumAtoms*numForceBuffers, "forceBuffers");
        force.initialize<mm_double4>(*this, &forceBuffers.getDeviceBuffer(), paddedNumAtoms, "force");
        energyBuffer.initialize<cl_double>(*this, energyBufferSize, "energyBuffer");
        energySum.initialize<cl_double>(*this, 1, "energySum");
505
    }
Peter Eastman's avatar
Peter Eastman committed
506
    else if (useMixedPrecision) {
peastman's avatar
peastman committed
507
508
509
510
        forceBuffers.initialize<mm_float4>(*this, paddedNumAtoms*numForceBuffers, "forceBuffers");
        force.initialize<mm_float4>(*this, &forceBuffers.getDeviceBuffer(), paddedNumAtoms, "force");
        energyBuffer.initialize<cl_double>(*this, energyBufferSize, "energyBuffer");
        energySum.initialize<cl_double>(*this, 1, "energySum");
Peter Eastman's avatar
Peter Eastman committed
511
512
    }
    else {
peastman's avatar
peastman committed
513
514
515
516
        forceBuffers.initialize<mm_float4>(*this, paddedNumAtoms*numForceBuffers, "forceBuffers");
        force.initialize<mm_float4>(*this, &forceBuffers.getDeviceBuffer(), paddedNumAtoms, "force");
        energyBuffer.initialize<cl_float>(*this, energyBufferSize, "energyBuffer");
        energySum.initialize<cl_float>(*this, 1, "energySum");
517
    }
518
519
520
521
    reduceForcesKernel.setArg<cl::Buffer>(0, longForceBuffer.getDeviceBuffer());
    reduceForcesKernel.setArg<cl::Buffer>(1, forceBuffers.getDeviceBuffer());
    reduceForcesKernel.setArg<cl_int>(2, paddedNumAtoms);
    reduceForcesKernel.setArg<cl_int>(3, numForceBuffers);
522
    addAutoclearBuffer(longForceBuffer);
peastman's avatar
peastman committed
523
524
    addAutoclearBuffer(forceBuffers);
    addAutoclearBuffer(energyBuffer);
525
526
527
    int numEnergyParamDerivs = energyParamDerivNames.size();
    if (numEnergyParamDerivs > 0) {
        if (useDoublePrecision || useMixedPrecision)
peastman's avatar
peastman committed
528
            energyParamDerivBuffer.initialize<cl_double>(*this, numEnergyParamDerivs*energyBufferSize, "energyParamDerivBuffer");
529
        else
peastman's avatar
peastman committed
530
531
            energyParamDerivBuffer.initialize<cl_float>(*this, numEnergyParamDerivs*energyBufferSize, "energyParamDerivBuffer");
        addAutoclearBuffer(energyParamDerivBuffer);
532
    }
533
    int bufferBytes = max(max((int) velm.getSize()*velm.getElementSize(),
534
            energyBufferSize*energyBuffer.getElementSize()),
535
            (int) longForceBuffer.getSize()*longForceBuffer.getElementSize());
536
    pinnedBuffer = new cl::Buffer(context, CL_MEM_ALLOC_HOST_PTR, bufferBytes);
537
    pinnedMemory = currentQueue.enqueueMapBuffer(*pinnedBuffer, CL_TRUE, CL_MAP_READ | CL_MAP_WRITE, 0, bufferBytes);
538
539
540
541
542
543
544
    for (int i = 0; i < numAtoms; i++) {
        double mass = system.getParticleMass(i);
        if (useDoublePrecision || useMixedPrecision)
            ((mm_double4*) pinnedMemory)[i] = mm_double4(0.0, 0.0, 0.0, mass == 0.0 ? 0.0 : 1.0/mass);
        else
            ((mm_float4*) pinnedMemory)[i] = mm_float4(0.0f, 0.0f, 0.0f, mass == 0.0 ? 0.0f : (cl_float) (1.0/mass));
    }
peastman's avatar
peastman committed
545
    velm.upload(pinnedMemory);
546
    findMoleculeGroups();
547
    nonbonded->initialize(system);
548
549
}

550
551
void OpenCLContext::initializeContexts() {
    getPlatformData().initializeContexts(system);
552
553
}

554
555
556
557
558
/**
 * Register a force with this context.
 *
 * Performs the generic ComputeContext registration.  If the force is an
 * OpenCLForceInfo, it also reserves the number of force buffers it asks for.
 */
void OpenCLContext::addForce(ComputeForceInfo* force) {
    ComputeContext::addForce(force);
    // Only OpenCL-specific force infos carry a force-buffer requirement.
    if (OpenCLForceInfo* openclInfo = dynamic_cast<OpenCLForceInfo*>(force))
        requestForceBuffers(openclInfo->getRequiredForceBuffers());
}

561
562
/**
 * Ensure that at least minBuffers force buffers will be used.  The stored
 * count only ever grows; requests smaller than the current count are ignored.
 */
void OpenCLContext::requestForceBuffers(int minBuffers) {
    if (minBuffers > numForceBuffers)
        numForceBuffers = minBuffers;
}

565
566
/**
 * Build an OpenCL program without any caller-supplied preprocessor defines.
 * This is equivalent to calling the three-argument overload with an empty map.
 *
 * @param source             OpenCL C source of the program
 * @param optimizationFlags  compiler options, or NULL for the context default
 */
cl::Program OpenCLContext::createProgram(const string source, const char* optimizationFlags) {
    const map<string, string> noExtraDefines;
    return createProgram(source, noExtraDefines, optimizationFlags);
}

569
/**
 * Build an OpenCL program for this context's device.
 *
 * The source handed to the compiler is assembled in this order:
 *   1. a comment recording the compiler options (if any),
 *   2. the context-wide compilationDefines, skipping any name that also
 *      appears in `defines`,
 *   3. the cl_khr_fp64 pragma (when the device supports double precision),
 *   4. the `real*` and `mixed*` typedefs for the current precision mode,
 *   5. OpenCLKernelSources::common,
 *   6. the caller's `defines`,
 *   7. `source` itself.
 *
 * @param source             OpenCL C source of the program
 * @param defines            macros to #define just before `source`; these take
 *                           precedence over identically named compilationDefines
 * @param optimizationFlags  compiler options, or NULL to use defaultOptimizationOptions
 * @return the built program
 * @throws OpenMMException if the build fails; the message contains the build log
 */
cl::Program OpenCLContext::createProgram(const string source, const map<string, string>& defines, const char* optimizationFlags) {
    string options = (optimizationFlags == NULL ? defaultOptimizationOptions : string(optimizationFlags));
    stringstream src;
    if (!options.empty())
        src << "// Compilation Options: " << options << endl << endl;
    for (auto& pair : compilationDefines) {
        // Skip names the caller also defines, so no macro is defined twice.
        if (defines.find(pair.first) == defines.end()) {
            src << "#define " << pair.first;
            if (!pair.second.empty())
                src << " " << pair.second;
            src << endl;
        }
    }
    if (!compilationDefines.empty())
        src << endl;
    if (supportsDoublePrecision)
        src << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n";

    // "real" is double only in full double precision mode.
    if (useDoublePrecision) {
        src << "typedef double real;\n";
        src << "typedef double2 real2;\n";
        src << "typedef double3 real3;\n";
        src << "typedef double4 real4;\n";
    }
    else {
        src << "typedef float real;\n";
        src << "typedef float2 real2;\n";
        src << "typedef float3 real3;\n";
        src << "typedef float4 real4;\n";
    }

    // "mixed" is double in both double and mixed precision modes.
    if (useDoublePrecision || useMixedPrecision) {
        src << "typedef double mixed;\n";
        src << "typedef double2 mixed2;\n";
        src << "typedef double3 mixed3;\n";
        src << "typedef double4 mixed4;\n";
    }
    else {
        src << "typedef float mixed;\n";
        src << "typedef float2 mixed2;\n";
        src << "typedef float3 mixed3;\n";
        src << "typedef float4 mixed4;\n";
    }
    src << OpenCLKernelSources::common << endl;
    for (auto& pair : defines) {
        src << "#define " << pair.first;
        if (!pair.second.empty())
            src << " " << pair.second;
        src << endl;
    }
    if (!defines.empty())
        src << endl;
    src << source << endl;
    cl::Program::Sources sources({src.str()});
    cl::Program program(context, sources);
    try {
        program.build(vector<cl::Device>(1, device), options.c_str());
    } catch (const cl::Error&) {
        // The exception only carries an error code; the build log is what explains the failure.
        throw OpenMMException("Error compiling kernel: "+program.getBuildInfo<CL_PROGRAM_BUILD_LOG>(device));
    }
    return program;
}

631
632
633
634
635
636
637
638
639
640
641
642
/**
 * Get the command queue that executeKernel() and flushQueue() currently use.
 */
cl::CommandQueue& OpenCLContext::getQueue() {
    cl::CommandQueue& active = currentQueue;
    return active;
}

/**
 * Make `queue` the current queue.  Work is enqueued on it until setQueue() is
 * called again or restoreDefaultQueue() is called.
 */
void OpenCLContext::setQueue(cl::CommandQueue& queue) {
    currentQueue = queue;
}

/**
 * Make the context's default queue the current queue again.
 */
void OpenCLContext::restoreDefaultQueue() {
    currentQueue = defaultQueue;
}

643
644
645
646
647
648
/**
 * Allocate a new, uninitialized OpenCLArray.  The array is created with `new`
 * and is not retained here, so the caller takes ownership of it.
 */
OpenCLArray* OpenCLContext::createArray() {
    OpenCLArray* fresh = new OpenCLArray();
    return fresh;
}

/**
 * Create a platform-independent event handle backed by an OpenCLEvent bound to
 * this context.
 */
ComputeEvent OpenCLContext::createEvent() {
    OpenCLEvent* event = new OpenCLEvent(*this);
    return shared_ptr<ComputeEventImpl>(event);
}

651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
/**
 * Compile `source` with the given preprocessor defines and wrap the result in a
 * platform-independent ComputeProgram handle.
 */
ComputeProgram OpenCLContext::compileProgram(const std::string source, const std::map<std::string, std::string>& defines) {
    cl::Program built = createProgram(source, defines);
    OpenCLProgram* wrapper = new OpenCLProgram(*this, built);
    return shared_ptr<ComputeProgramImpl>(wrapper);
}

/**
 * Get the OpenCLArray behind a generic array reference.
 *
 * If `array` is a ComputeArray wrapper, the array it holds is examined.
 * Otherwise `array` itself must be an OpenCLArray.
 *
 * @throws OpenMMException if the underlying array is not an OpenCLArray
 */
OpenCLArray& OpenCLContext::unwrap(ArrayInterface& array) const {
    OpenCLArray* clarray = NULL;
    if (ComputeArray* wrapper = dynamic_cast<ComputeArray*>(&array))
        clarray = dynamic_cast<OpenCLArray*>(&wrapper->getArray());
    else
        clarray = dynamic_cast<OpenCLArray*>(&array);
    if (clarray == NULL)
        throw OpenMMException("Array argument is not an OpenCLArray");
    return *clarray;
}

668
669
670
671
/**
 * Enqueue a one-dimensional launch of `kernel` on the current queue.
 *
 * The global size is workUnits rounded up to whole work groups of blockSize
 * items, but never more than numThreadBlocks work groups.  When workUnits
 * exceeds that cap, the kernel presumably loops over the remaining items
 * (TODO confirm per kernel).
 *
 * @param kernel     kernel to launch; all of its arguments must already be set
 * @param workUnits  number of work items the launch should cover.  Nothing is
 *                   enqueued when this is zero or negative.
 * @param blockSize  work-group size, or -1 to use ThreadBlockSize
 * @throws OpenMMException if the OpenCL runtime rejects the launch
 */
void OpenCLContext::executeKernel(cl::Kernel& kernel, int workUnits, int blockSize) {
    if (blockSize == -1)
        blockSize = ThreadBlockSize;
    // An empty launch has no work to do.  A zero global size is also rejected
    // with CL_INVALID_GLOBAL_WORK_SIZE by implementations older than OpenCL 2.1,
    // and a negative one would wrap when converted to size_t.
    if (workUnits <= 0)
        return;
    int size = std::min((workUnits+blockSize-1)/blockSize, numThreadBlocks)*blockSize;
    try {
        currentQueue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange(size), cl::NDRange(blockSize));
    }
    catch (const cl::Error& err) {
        stringstream str;
        str<<"Error invoking kernel "<<kernel.getInfo<CL_KERNEL_FUNCTION_NAME>()<<": "<<err.what()<<" ("<<err.err()<<")";
        throw OpenMMException(str.str());
    }
}

682
683
684
685
686
687
688
689
690
691
692
693
694
/**
 * Choose a work-group size that fits in the device's local memory.
 *
 * @param memory  local memory each work item needs (presumably in bytes, to
 *                match CL_DEVICE_LOCAL_MEM_SIZE)
 * @return 32 if fewer than 64 items fit.  Otherwise, the largest multiple of 64
 *         strictly below the number that fits, but never less than 64.
 */
int OpenCLContext::computeThreadBlockSize(double memory) const {
    int maxShared = device.getInfo<CL_DEVICE_LOCAL_MEM_SIZE>();
    // Some implementations use more local memory than the sum of the field
    // sizes, so only budget for half of what the device reports.
    int fitting = (int) (0.5*maxShared/memory);
    if (fitting < 64)
        return 32;
    int groupsOf64 = (fitting-1)/64;
    return 64*(groupsOf64 < 1 ? 1 : groupsOf64);
}

695
696
/**
 * Clear the whole device buffer behind `array` (see the cl::Memory overload).
 */
void OpenCLContext::clearBuffer(ArrayInterface& array) {
    int totalBytes = array.getSize()*array.getElementSize();
    clearBuffer(unwrap(array).getDeviceBuffer(), totalBytes);
}

699
/**
 * Run clearBufferKernel over the first `size` bytes of `memory`.
 *
 * The kernel is given a count of 4-byte words, so any trailing bytes beyond a
 * multiple of 4 are not covered.
 */
void OpenCLContext::clearBuffer(cl::Memory& memory, int size) {
    const int wordCount = size/4;
    clearBufferKernel.setArg<cl::Memory>(0, memory);
    clearBufferKernel.setArg<cl_int>(1, wordCount);
    executeKernel(clearBufferKernel, wordCount, 128);
}

706
707
/**
 * Register the whole device buffer behind `array` for clearing by
 * clearAutoclearBuffers() (see the cl::Memory overload).
 */
void OpenCLContext::addAutoclearBuffer(ArrayInterface& array) {
    int totalBytes = array.getSize()*array.getElementSize();
    addAutoclearBuffer(unwrap(array).getDeviceBuffer(), totalBytes);
}

710
711
/**
 * Register `memory` to be cleared by clearAutoclearBuffers().
 *
 * Only a pointer to `memory` is stored, so it must stay valid for as long as it
 * is registered.  The size is recorded in 4-byte words (size/4).
 */
void OpenCLContext::addAutoclearBuffer(cl::Memory& memory, int size) {
    const int wordCount = size/4;
    autoclearBuffers.push_back(&memory);
    autoclearBufferSizes.push_back(wordCount);
}

void OpenCLContext::clearAutoclearBuffers() {
    int base = 0;
    int total = autoclearBufferSizes.size();
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
    while (total-base >= 6) {
        clearSixBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearSixBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearSixBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearSixBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearSixBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearSixBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearSixBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearSixBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
        clearSixBuffersKernel.setArg<cl::Memory>(8, *autoclearBuffers[base+4]);
        clearSixBuffersKernel.setArg<cl_int>(9, autoclearBufferSizes[base+4]);
        clearSixBuffersKernel.setArg<cl::Memory>(10, *autoclearBuffers[base+5]);
        clearSixBuffersKernel.setArg<cl_int>(11, autoclearBufferSizes[base+5]);
        executeKernel(clearSixBuffersKernel, max(max(max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), autoclearBufferSizes[base+4]), autoclearBufferSizes[base+5]), 128);
        base += 6;
    }
    if (total-base == 5) {
        clearFiveBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearFiveBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearFiveBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearFiveBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearFiveBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearFiveBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearFiveBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearFiveBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
        clearFiveBuffersKernel.setArg<cl::Memory>(8, *autoclearBuffers[base+4]);
        clearFiveBuffersKernel.setArg<cl_int>(9, autoclearBufferSizes[base+4]);
        executeKernel(clearFiveBuffersKernel, max(max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), autoclearBufferSizes[base+4]), 128);
    }
    else if (total-base == 4) {
748
749
750
751
752
753
754
755
        clearFourBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearFourBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearFourBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearFourBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearFourBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearFourBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearFourBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearFourBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
756
        executeKernel(clearFourBuffersKernel, max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), 128);
757
    }
758
    else if (total-base == 3) {
759
760
761
762
763
764
        clearThreeBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearThreeBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearThreeBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearThreeBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearThreeBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearThreeBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
765
        executeKernel(clearThreeBuffersKernel, max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), 128);
766
767
768
769
770
771
    }
    else if (total-base == 2) {
        clearTwoBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearTwoBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearTwoBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearTwoBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
772
        executeKernel(clearTwoBuffersKernel, max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), 128);
773
774
    }
    else if (total-base == 1) {
775
        clearBuffer(*autoclearBuffers[base], autoclearBufferSizes[base]*4);
776
777
778
    }
}

779
void OpenCLContext::reduceForces() {
780
    executeKernel(reduceForcesKernel, paddedNumAtoms, 128);
781
782
}

783
/**
 * Run reduceReal4Kernel on `array`, which is treated as numBuffers equal-sized
 * sub-buffers, together with `longBuffer`.
 *
 * @param array       array split into numBuffers consecutive sub-buffers
 * @param longBuffer  second buffer passed to the kernel (argument 1)
 * @param numBuffers  number of sub-buffers in `array`; must be nonzero
 */
void OpenCLContext::reduceBuffer(OpenCLArray& array, OpenCLArray& longBuffer, int numBuffers) {
    int elementsPerBuffer = array.getSize()/numBuffers;
    cl::Kernel& kernel = reduceReal4Kernel;
    kernel.setArg<cl::Buffer>(0, array.getDeviceBuffer());
    kernel.setArg<cl::Buffer>(1, longBuffer.getDeviceBuffer());
    kernel.setArg<cl_int>(2, elementsPerBuffer);
    kernel.setArg<cl_int>(3, numBuffers);
    executeKernel(kernel, elementsPerBuffer, 128);
}
791

Peter Eastman's avatar
Peter Eastman committed
792
793
794
795
double OpenCLContext::reduceEnergy() {
    int workGroupSize  = device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>();
    if (workGroupSize > 512)
        workGroupSize = 512;
peastman's avatar
peastman committed
796
797
798
    reduceEnergyKernel.setArg<cl::Buffer>(0, energyBuffer.getDeviceBuffer());
    reduceEnergyKernel.setArg<cl::Buffer>(1, energySum.getDeviceBuffer());
    reduceEnergyKernel.setArg<cl_int>(2, energyBuffer.getSize());
Peter Eastman's avatar
Peter Eastman committed
799
    reduceEnergyKernel.setArg<cl_int>(3, workGroupSize);
peastman's avatar
peastman committed
800
    reduceEnergyKernel.setArg(4, workGroupSize*energyBuffer.getElementSize(), NULL);
Peter Eastman's avatar
Peter Eastman committed
801
802
803
    executeKernel(reduceEnergyKernel, workGroupSize, workGroupSize);
    if (getUseDoublePrecision() || getUseMixedPrecision()) {
        double energy;
peastman's avatar
peastman committed
804
        energySum.download(&energy);
Peter Eastman's avatar
Peter Eastman committed
805
806
807
808
        return energy;
    }
    else {
        float energy;
peastman's avatar
peastman committed
809
        energySum.download(&energy);
Peter Eastman's avatar
Peter Eastman committed
810
811
812
813
        return energy;
    }
}

814
void OpenCLContext::setCharges(const vector<double>& charges) {
peastman's avatar
peastman committed
815
816
    if (!chargeBuffer.isInitialized())
        chargeBuffer.initialize(*this, numAtoms, useDoublePrecision ? sizeof(double) : sizeof(float), "chargeBuffer");
peastman's avatar
peastman committed
817
818
819
    vector<double> c(numAtoms);
    for (int i = 0; i < numAtoms; i++)
        c[i] = charges[i];
820
    chargeBuffer.upload(c, true);
peastman's avatar
peastman committed
821
822
823
    setChargesKernel.setArg<cl::Buffer>(0, chargeBuffer.getDeviceBuffer());
    setChargesKernel.setArg<cl::Buffer>(1, posq.getDeviceBuffer());
    setChargesKernel.setArg<cl::Buffer>(2, atomIndexDevice.getDeviceBuffer());
824
825
826
827
    setChargesKernel.setArg<cl_int>(3, numAtoms);
    executeKernel(setChargesKernel, numAtoms);
}

828
829
830
831
832
833
/**
 * Claim the right to store charges in posq.
 *
 * @return true for the first caller only; every later call returns false
 */
bool OpenCLContext::requestPosqCharges() {
    if (hasAssignedPosqCharges)
        return false;
    hasAssignedPosqCharges = true;
    return true;
}

834
835
836
837
838
839
840
841
842
/**
 * Register a parameter whose energy derivative should be computed.
 * Registering the same name more than once has no further effect.
 */
void OpenCLContext::addEnergyParameterDerivative(const string& param) {
    for (const auto& existing : energyParamDerivNames)
        if (param == existing)
            return;
    energyParamDerivNames.push_back(param);
}

843
844
void OpenCLContext::flushQueue() {
    getQueue().flush();
845
}