OpenCLContext.cpp 40.9 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2020 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

27
28
29
30
#ifdef WIN32
  #define _USE_MATH_DEFINES // Needed to get M_PI
#endif
#include <cmath>
31
32
#include "OpenCLContext.h"
#include "OpenCLArray.h"
Peter Eastman's avatar
Peter Eastman committed
33
#include "OpenCLBondedUtilities.h"
34
#include "OpenCLEvent.h"
35
#include "OpenCLForceInfo.h"
36
#include "OpenCLIntegrationUtilities.h"
37
#include "OpenCLKernelSources.h"
38
#include "OpenCLNonbondedUtilities.h"
39
40
#include "OpenCLProgram.h"
#include "openmm/common/ComputeArray.h"
41
#include "openmm/Platform.h"
42
#include "openmm/System.h"
43
#include "openmm/VirtualSite.h"
44
#include "openmm/internal/ContextImpl.h"
Peter Eastman's avatar
Peter Eastman committed
45
#include <algorithm>
46
47
#include <fstream>
#include <iostream>
48
#include <set>
49
#include <sstream>
50
#include <typeinfo>
51
52

using namespace OpenMM;
53
using namespace std;
54

55
56
57
// Device-info token used with clGetDeviceInfo() further down in this file to
// read the major compute capability of a device that advertises the
// cl_nv_device_attribute_query extension.  It is defined here only when the
// OpenCL headers being compiled against do not already provide it.
#ifndef CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV
  #define CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV 0x4000
#endif
58
59
60
// Fallback definition of the companion token to
// CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV (presumably the minor compute
// capability, judging by the name).  Defined only when the OpenCL headers do
// not already provide it.
#ifndef CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV
  #define CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV 0x4001
#endif
61

62
63
64
// ThreadBlockSize is exported to every kernel as the WORK_GROUP_SIZE define and
// is used when sizing the energy buffer in initialize().
const int OpenCLContext::ThreadBlockSize = 64;
// TileSize is the granularity to which the atom count is padded
// (paddedNumAtoms is always a multiple of it).
const int OpenCLContext::TileSize = 32;

65
/**
 * Error notification callback registered when the cl::Context is created.
 *
 * Messages that begin with the build-warning banner are dropped, because some
 * OpenCL implementations (OS X Lion) invoke the callback for every build
 * warning even though those are not errors.  Everything else is written to
 * std::cerr.
 *
 * @param errinfo       the message supplied by the OpenCL runtime (a null
 *                      pointer is treated as an empty message)
 * @param private_info  unused
 * @param cb            unused
 * @param user_data     unused
 */
static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_info, size_t cb, void* user_data) {
    // Built once; std::string::compare is used instead of strncmp so this file
    // does not rely on <cstring> being pulled in by some other header.
    static const std::string skip = "OpenCL Build Warning : Compiler build log:";
    // Guard against a null message: streaming a null const char* is undefined behavior.
    const std::string message(errinfo != NULL ? errinfo : "");
    if (message.compare(0, skip.length(), skip) == 0)
        return; // OS X Lion insists on calling this for every build warning, even though they aren't errors.
    std::cerr << "OpenCL internal error: " << message << std::endl;
}

72
73
74
75
76
77
78
79
/**
 * Report whether an OpenCL platform comes from one of the supported vendors.
 *
 * A platform counts as supported when its CL_PLATFORM_VENDOR string starts
 * with one of the prefixes listed below.
 */
static bool isSupported(cl::Platform platform) {
    static const char* const supportedVendorPrefixes[] = {
        "NVIDIA",
        "Advanced Micro Devices",
        "Apple",
        "Intel"
    };
    const string vendor = platform.getInfo<CL_PLATFORM_VENDOR>();
    for (const char* prefix : supportedVendorPrefixes) {
        // find() returning 0 means the vendor string begins with this prefix.
        if (vendor.find(prefix) == 0)
            return true;
    }
    return false;
}

80
/**
 * Create an OpenCLContext: pick a platform and device, create (or share) the
 * cl::Context and command queue, allocate the per-atom arrays, build the
 * utility kernels, and populate the compilation defines used by every kernel.
 *
 * @param system           the System this context will simulate
 * @param platformIndex    index of the OpenCL platform to use, or -1 to search all platforms
 * @param deviceIndex      index of the device within the platform, or -1 to pick the one
 *                         with the highest estimated speed
 * @param precision        "single", "mixed", or "double"
 * @param platformData     the PlatformData this context belongs to
 * @param originalContext  if not NULL, the cl::Context and default queue of that
 *                         context are reused instead of creating new ones
 * @throws OpenMMException if an argument is illegal, no usable platform/device is found,
 *                         the device lacks required features, or OpenCL reports an error
 *                         while the context is being set up
 */
OpenCLContext::OpenCLContext(const System& system, int platformIndex, int deviceIndex, const string& precision, OpenCLPlatform::PlatformData& platformData, OpenCLContext* originalContext) :
        ComputeContext(system), platformData(platformData), numForceBuffers(0), hasAssignedPosqCharges(false),
        integration(NULL), expression(NULL), bonded(NULL), nonbonded(NULL) {
    // The destructor tests pinnedBuffer against NULL, but it is only allocated in initialize().
    // Give it a defined value so destroying a context that was never initialized is safe.
    pinnedBuffer = NULL;
    if (precision == "single") {
        useDoublePrecision = false;
        useMixedPrecision = false;
    }
    else if (precision == "mixed") {
        useDoublePrecision = false;
        useMixedPrecision = true;
    }
    else if (precision == "double") {
        useDoublePrecision = true;
        useMixedPrecision = false;
    }
    else
        throw OpenMMException("Illegal value for Precision: "+precision);
    try {
        contextIndex = platformData.contexts.size();
        std::vector<cl::Platform> platforms;
        cl::Platform::get(&platforms);
        if (platformIndex < -1 || platformIndex >= (int) platforms.size())
            throw OpenMMException("Illegal value for OpenCLPlatformIndex: "+intToString(platformIndex));
        if (platforms.size() > 1 && platformIndex == -1 && deviceIndex != -1)
            throw OpenMMException("Specified DeviceIndex but not OpenCLPlatformIndex.  When multiple platforms are available, a platform index is needed to specify a device.");
        const int minThreadBlockSize = 32;

        // Search the candidate platforms and devices for the one with the highest estimated speed.

        int bestSpeed = -1;
        int bestDevice = -1;
        int bestPlatform = -1;
        bool bestSupported = false;
        for (int j = 0; j < (int) platforms.size(); j++) {
            // If they supplied a valid platformIndex, we only look through that platform
            if (j != platformIndex && platformIndex != -1)
                continue;

            // Always prefer a supported platform over an unsupported one.
            bool supported = isSupported(platforms[j]);
            if (!supported && bestSupported)
                continue;
            string platformVendor = platforms[j].getInfo<CL_PLATFORM_VENDOR>();
            vector<cl::Device> devices;
            try {
                platforms[j].getDevices(CL_DEVICE_TYPE_ALL, &devices);
            }
            catch (...) {
                // There are no devices available for this platform.
                continue;
            }
            if (deviceIndex < -1 || deviceIndex >= (int) devices.size())
                throw OpenMMException("Illegal value for DeviceIndex: "+intToString(deviceIndex));

            for (int i = 0; i < (int) devices.size(); i++) {
                // If they supplied a valid deviceIndex, we only look through that one
                if (i != deviceIndex && deviceIndex != -1)
                    continue;
                if (platformVendor == "Apple" && (devices[i].getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU))
                    continue; // The CPU device on OS X won't work correctly.
                if (useMixedPrecision || useDoublePrecision) {
                    bool supportsDouble = (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_fp64") != string::npos);
                    if (!supportsDouble)
                        continue; // This device does not support double precision.
                }
                int maxSize = devices[i].getInfo<CL_DEVICE_MAX_WORK_ITEM_SIZES>()[0];
                int processingElementsPerComputeUnit = 8;
                if (devices[i].getInfo<CL_DEVICE_TYPE>() != CL_DEVICE_TYPE_GPU) {
                    processingElementsPerComputeUnit = 1;
                }
                else if (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_nv_device_attribute_query") != string::npos) {
                    // Start from 0 so a failed query yields a defined (conservative) value
                    // instead of reading an uninitialized variable.
                    cl_uint computeCapabilityMajor = 0;
                    clGetDeviceInfo(devices[i](), CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV, sizeof(cl_uint), &computeCapabilityMajor, NULL);
                    processingElementsPerComputeUnit = (computeCapabilityMajor < 2 ? 8 : 32);
                }
                else if (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_amd_device_attribute_query") != string::npos) {
                    // This attribute does not ensure that all queries are supported by the runtime (it may be an older runtime,
                    // or the CPU device) so still have to check for errors.
                    try {
                        // AMD GPUs either have a single VLIW SIMD or multiple scalar SIMDs.
                        // The SIMD width is the number of threads the SIMD executes per cycle.
                        // This will be less than the wavefront width since it takes several
                        // cycles to execute the full wavefront.
                        // The SIMD instruction width is the VLIW instruction width (or 1 for scalar),
                        // this is the number of ALUs that can be executing per instruction per thread.
                        processingElementsPerComputeUnit =
                            devices[i].getInfo<CL_DEVICE_SIMD_PER_COMPUTE_UNIT_AMD>() *
                            devices[i].getInfo<CL_DEVICE_SIMD_WIDTH_AMD>() *
                            devices[i].getInfo<CL_DEVICE_SIMD_INSTRUCTION_WIDTH_AMD>();
                        // Just in case any of the queries return 0.
                        if (processingElementsPerComputeUnit <= 0)
                            processingElementsPerComputeUnit = 1;
                    }
                    catch (const cl::Error&) {
                        // Runtime does not support the queries so use default.
                    }
                }
                int speed = devices[i].getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>()*processingElementsPerComputeUnit*devices[i].getInfo<CL_DEVICE_MAX_CLOCK_FREQUENCY>();
                if (maxSize >= minThreadBlockSize && (speed > bestSpeed || (supported && !bestSupported))) {
                    bestDevice = i;
                    bestSpeed = speed;
                    bestPlatform = j;
                    bestSupported = supported;
                }
            }
        }

        if (bestPlatform == -1)
            throw OpenMMException("No compatible OpenCL platform is available");

        if (bestDevice == -1)
            throw OpenMMException("No compatible OpenCL device is available");

        if (!bestSupported)
            cout << "WARNING: Using an unsupported OpenCL implementation.  Results may be incorrect." << endl;

        vector<cl::Device> devices;
        platforms[bestPlatform].getDevices(CL_DEVICE_TYPE_ALL, &devices);
        string platformVendor = platforms[bestPlatform].getInfo<CL_PLATFORM_VENDOR>();
        device = devices[bestDevice];

        this->deviceIndex = bestDevice;
        this->platformIndex = bestPlatform;
        if (device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>() < minThreadBlockSize)
            throw OpenMMException("The specified OpenCL device is not compatible with OpenMM");
        compilationDefines["WORK_GROUP_SIZE"] = intToString(ThreadBlockSize);
        if (platformVendor.size() >= 5 && platformVendor.substr(0, 5) == "Intel")
            defaultOptimizationOptions = "";
        else
            defaultOptimizationOptions = "-cl-mad-enable -cl-no-signed-zeros";
        supports64BitGlobalAtomics = (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_int64_base_atomics") != string::npos);
        supportsDoublePrecision = (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_fp64") != string::npos);
        if ((useDoublePrecision || useMixedPrecision) && !supportsDoublePrecision)
            throw OpenMMException("This device does not support double precision");
        string vendor = device.getInfo<CL_DEVICE_VENDOR>();
        int numThreadBlocksPerComputeUnit = 6;
        if (vendor.size() >= 6 && vendor.substr(0, 6) == "NVIDIA") {
            compilationDefines["WARPS_ARE_ATOMIC"] = "";
            simdWidth = 32;
            if (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_nv_device_attribute_query") != string::npos) {
                // Compute level 1.2 and later Nvidia GPUs support 64 bit atomics, even though they don't list the
                // proper extension as supported.  We only use them on compute level 2.0 or later, since they're very
                // slow on earlier GPUs.

                // Start from 0 so a failed query leaves the extension-based setting untouched.
                cl_uint computeCapabilityMajor = 0;
                clGetDeviceInfo(device(), CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV, sizeof(cl_uint), &computeCapabilityMajor, NULL);
                if (computeCapabilityMajor > 1)
                    supports64BitGlobalAtomics = true;
                if (computeCapabilityMajor == 5) {
                    // Workaround for a bug in Maxwell on CUDA 6.x.

                    string platformVersion = platforms[bestPlatform].getInfo<CL_PLATFORM_VERSION>();
                    if (platformVersion.find("CUDA 6") != string::npos)
                        supports64BitGlobalAtomics = false;
                }
            }
        }
        else if (vendor.size() >= 28 && vendor.substr(0, 28) == "Advanced Micro Devices, Inc.") {
            if (device.getInfo<CL_DEVICE_TYPE>() != CL_DEVICE_TYPE_GPU) {
                /// \todo Is 6 a good value for the OpenCL CPU device?
                // numThreadBlocksPerComputeUnit = ?;
                simdWidth = 1;
            }
            else {
                bool amdPostSdk2_4 = false;
                // Default to 1 which will use the default kernels.
                simdWidth = 1;
                if (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_amd_device_attribute_query") != string::npos) {
                    // This attribute does not ensure that all queries are supported by the runtime so still have to
                    // check for errors.
                    try {
                        // Must catch cl:Error as will fail if runtime does not support queries.

                        cl_uint simdPerComputeUnit = device.getInfo<CL_DEVICE_SIMD_PER_COMPUTE_UNIT_AMD>();
                        simdWidth = device.getInfo<CL_DEVICE_WAVEFRONT_WIDTH_AMD>();

                        // If the GPU has multiple SIMDs per compute unit then it uses the scalar instruction
                        // set instead of the VLIW instruction set. It therefore needs more thread blocks per
                        // compute unit to hide memory latency.
                        if (simdPerComputeUnit > 1) {
                            if (simdWidth == 32)
                                numThreadBlocksPerComputeUnit = 6*simdPerComputeUnit; // Navi seems to like more thread blocks than older GPUs
                            else
                                numThreadBlocksPerComputeUnit = 4*simdPerComputeUnit;
                        }

                        // If the queries are supported then must be newer than SDK 2.4.
                        amdPostSdk2_4 = true;
                    }
                    catch (const cl::Error&) {
                        // Runtime does not support the query so is unlikely to be the newer scalar GPU.
                        // Stay with the default simdWidth and numThreadBlocksPerComputeUnit.
                    }
                }
                // AMD APP SDK 2.4 has a performance problem with atomics. Enable the work around. This is fixed after SDK 2.4.
                if (!amdPostSdk2_4)
                    compilationDefines["AMD_ATOMIC_WORK_AROUND"] = "";
            }
        }
        else
            simdWidth = 1;
        if (supports64BitGlobalAtomics)
            compilationDefines["SUPPORTS_64_BIT_ATOMICS"] = "";
        if (supportsDoublePrecision)
            compilationDefines["SUPPORTS_DOUBLE_PRECISION"] = "";
        if (simdWidth >= 32)
            compilationDefines["SYNC_WARPS"] = "mem_fence(CLK_LOCAL_MEM_FENCE)";
        else
            compilationDefines["SYNC_WARPS"] = "barrier(CLK_LOCAL_MEM_FENCE)";
        vector<cl::Device> contextDevices;
        contextDevices.push_back(device);
        cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[bestPlatform](), 0};
        if (originalContext == NULL) {
            context = cl::Context(contextDevices, cprops, errorCallback);
            defaultQueue = cl::CommandQueue(context, device);
        }
        else {
            // Share the OpenCL context and queue of the context this one was derived from.
            context = originalContext->context;
            defaultQueue = originalContext->defaultQueue;
        }
        currentQueue = defaultQueue;
        numAtoms = system.getNumParticles();
        paddedNumAtoms = TileSize*((numAtoms+TileSize-1)/TileSize);
        numAtomBlocks = (paddedNumAtoms+(TileSize-1))/TileSize;
        numThreadBlocks = numThreadBlocksPerComputeUnit*device.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
        if (useDoublePrecision) {
            posq.initialize<mm_double4>(*this, paddedNumAtoms, "posq");
            velm.initialize<mm_double4>(*this, paddedNumAtoms, "velm");
            compilationDefines["USE_DOUBLE_PRECISION"] = "1";
            compilationDefines["convert_real4"] = "convert_double4";
            compilationDefines["make_real2"] = "make_double2";
            compilationDefines["make_real3"] = "make_double3";
            compilationDefines["make_real4"] = "make_double4";
            compilationDefines["convert_mixed4"] = "convert_double4";
            compilationDefines["make_mixed2"] = "make_double2";
            compilationDefines["make_mixed3"] = "make_double3";
            compilationDefines["make_mixed4"] = "make_double4";
        }
        else if (useMixedPrecision) {
            posq.initialize<mm_float4>(*this, paddedNumAtoms, "posq");
            // Give the correction array its own name so diagnostics can tell it apart from posq.
            posqCorrection.initialize<mm_float4>(*this, paddedNumAtoms, "posqCorrection");
            velm.initialize<mm_double4>(*this, paddedNumAtoms, "velm");
            compilationDefines["USE_MIXED_PRECISION"] = "1";
            compilationDefines["convert_real4"] = "convert_float4";
            compilationDefines["make_real2"] = "make_float2";
            compilationDefines["make_real3"] = "make_float3";
            compilationDefines["make_real4"] = "make_float4";
            compilationDefines["convert_mixed4"] = "convert_double4";
            compilationDefines["make_mixed2"] = "make_double2";
            compilationDefines["make_mixed3"] = "make_double3";
            compilationDefines["make_mixed4"] = "make_double4";
        }
        else {
            posq.initialize<mm_float4>(*this, paddedNumAtoms, "posq");
            velm.initialize<mm_float4>(*this, paddedNumAtoms, "velm");
            compilationDefines["convert_real4"] = "convert_float4";
            compilationDefines["make_real2"] = "make_float2";
            compilationDefines["make_real3"] = "make_float3";
            compilationDefines["make_real4"] = "make_float4";
            compilationDefines["convert_mixed4"] = "convert_float4";
            compilationDefines["make_mixed2"] = "make_float2";
            compilationDefines["make_mixed3"] = "make_float3";
            compilationDefines["make_mixed4"] = "make_float4";
        }
        longForceBuffer.initialize<cl_long>(*this, 3*paddedNumAtoms, "longForceBuffer");
        posCellOffsets.resize(paddedNumAtoms, mm_int4(0, 0, 0, 0));
        atomIndexDevice.initialize<cl_int>(*this, paddedNumAtoms, "atomIndexDevice");
        atomIndex.resize(paddedNumAtoms);
        for (int i = 0; i < paddedNumAtoms; ++i)
            atomIndex[i] = i;
        atomIndexDevice.upload(atomIndex);
    }
    catch (const cl::Error& err) {
        // Translate OpenCL errors into the exception type callers expect.
        std::stringstream str;
        str<<"Error initializing context: "<<err.what()<<" ("<<err.err()<<")";
        throw OpenMMException(str.str());
    }

    // Create utility kernels that are used in multiple places.

    cl::Program utilities = createProgram(OpenCLKernelSources::utilities);
    clearBufferKernel = cl::Kernel(utilities, "clearBuffer");
    clearTwoBuffersKernel = cl::Kernel(utilities, "clearTwoBuffers");
    clearThreeBuffersKernel = cl::Kernel(utilities, "clearThreeBuffers");
    clearFourBuffersKernel = cl::Kernel(utilities, "clearFourBuffers");
    clearFiveBuffersKernel = cl::Kernel(utilities, "clearFiveBuffers");
    clearSixBuffersKernel = cl::Kernel(utilities, "clearSixBuffers");
    reduceReal4Kernel = cl::Kernel(utilities, "reduceReal4Buffer");
    reduceForcesKernel = cl::Kernel(utilities, "reduceForces");
    reduceEnergyKernel = cl::Kernel(utilities, "reduceEnergy");
    setChargesKernel = cl::Kernel(utilities, "setCharges");

    // Decide whether native_sqrt(), native_rsqrt(), and native_recip() are sufficiently accurate to use.

    if (!useDoublePrecision) {
        cl::Kernel accuracyKernel(utilities, "determineNativeAccuracy");
        OpenCLArray valuesArray(*this, 20, sizeof(mm_float8), "values");
        vector<mm_float8> values(valuesArray.getSize());
        // Test values form a geometric sequence starting at 1e-4 with ratio pi.
        float nextValue = 1e-4f;
        for (auto& val : values) {
            val.s0 = nextValue;
            nextValue *= (float) M_PI;
        }
        valuesArray.upload(values);
        accuracyKernel.setArg<cl::Buffer>(0, valuesArray.getDeviceBuffer());
        accuracyKernel.setArg<cl_int>(1, values.size());
        executeKernel(accuracyKernel, values.size());
        valuesArray.download(values);
        // Compare each device result against the host's double precision value.
        double maxSqrtError = 0.0, maxRsqrtError = 0.0, maxRecipError = 0.0, maxExpError = 0.0, maxLogError = 0.0;
        for (auto& val : values) {
            double v = val.s0;
            double correctSqrt = sqrt(v);
            maxSqrtError = max(maxSqrtError, fabs(correctSqrt-val.s1)/correctSqrt);
            maxRsqrtError = max(maxRsqrtError, fabs(1.0/correctSqrt-val.s2)*correctSqrt);
            maxRecipError = max(maxRecipError, fabs(1.0/v-val.s3)/val.s3);
            maxExpError = max(maxExpError, fabs(exp(v)-val.s4)/val.s4);
            maxLogError = max(maxLogError, fabs(log(v)-val.s5)/val.s5);
        }
        compilationDefines["SQRT"] = (maxSqrtError < 1e-6) ? "native_sqrt" : "sqrt";
        compilationDefines["RSQRT"] = (maxRsqrtError < 1e-6) ? "native_rsqrt" : "rsqrt";
        compilationDefines["RECIP"] = (maxRecipError < 1e-6) ? "native_recip" : "1.0f/";
        compilationDefines["EXP"] = (maxExpError < 1e-6) ? "native_exp" : "exp";
        compilationDefines["LOG"] = (maxLogError < 1e-6) ? "native_log" : "log";
    }
    else {
        compilationDefines["SQRT"] = "sqrt";
        compilationDefines["RSQRT"] = "rsqrt";
        compilationDefines["RECIP"] = "1.0/";
        compilationDefines["EXP"] = "exp";
        compilationDefines["LOG"] = "log";
    }
    compilationDefines["POW"] = "pow";
    compilationDefines["COS"] = "cos";
    compilationDefines["SIN"] = "sin";
    compilationDefines["TAN"] = "tan";
    compilationDefines["ACOS"] = "acos";
    compilationDefines["ASIN"] = "asin";
    compilationDefines["ATAN"] = "atan";
    compilationDefines["ERF"] = "erf";
    compilationDefines["ERFC"] = "erfc";

    // Set defines for applying periodic boundary conditions.

    Vec3 boxVectors[3];
    system.getDefaultPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
    // The box is triclinic if any off-diagonal component of the box vectors is nonzero.
    boxIsTriclinic = (boxVectors[0][1] != 0.0 || boxVectors[0][2] != 0.0 ||
                      boxVectors[1][0] != 0.0 || boxVectors[1][2] != 0.0 ||
                      boxVectors[2][0] != 0.0 || boxVectors[2][1] != 0.0);
    if (boxIsTriclinic) {
        compilationDefines["APPLY_PERIODIC_TO_DELTA(delta)"] =
            "{"
            "real scale3 = floor(delta.z*invPeriodicBoxSize.z+0.5f); \\\n"
            "delta.xyz -= scale3*periodicBoxVecZ.xyz; \\\n"
            "real scale2 = floor(delta.y*invPeriodicBoxSize.y+0.5f); \\\n"
            "delta.xy -= scale2*periodicBoxVecY.xy; \\\n"
            "real scale1 = floor(delta.x*invPeriodicBoxSize.x+0.5f); \\\n"
            "delta.x -= scale1*periodicBoxVecX.x;}";
        compilationDefines["APPLY_PERIODIC_TO_POS(pos)"] =
            "{"
            "real scale3 = floor(pos.z*invPeriodicBoxSize.z); \\\n"
            "pos.xyz -= scale3*periodicBoxVecZ.xyz; \\\n"
            "real scale2 = floor(pos.y*invPeriodicBoxSize.y); \\\n"
            "pos.xy -= scale2*periodicBoxVecY.xy; \\\n"
            "real scale1 = floor(pos.x*invPeriodicBoxSize.x); \\\n"
            "pos.x -= scale1*periodicBoxVecX.x;}";
        compilationDefines["APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center)"] =
            "{"
            "real scale3 = floor((pos.z-center.z)*invPeriodicBoxSize.z+0.5f); \\\n"
            "pos.x -= scale3*periodicBoxVecZ.x; \\\n"
            "pos.y -= scale3*periodicBoxVecZ.y; \\\n"
            "pos.z -= scale3*periodicBoxVecZ.z; \\\n"
            "real scale2 = floor((pos.y-center.y)*invPeriodicBoxSize.y+0.5f); \\\n"
            "pos.x -= scale2*periodicBoxVecY.x; \\\n"
            "pos.y -= scale2*periodicBoxVecY.y; \\\n"
            "real scale1 = floor((pos.x-center.x)*invPeriodicBoxSize.x+0.5f); \\\n"
            "pos.x -= scale1*periodicBoxVecX.x;}";
    }
    else {
        compilationDefines["APPLY_PERIODIC_TO_DELTA(delta)"] =
            "delta.xyz -= floor(delta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;";
        compilationDefines["APPLY_PERIODIC_TO_POS(pos)"] =
            "pos.xyz -= floor(pos.xyz*invPeriodicBoxSize.xyz)*periodicBoxSize.xyz;";
        compilationDefines["APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center)"] =
            "{"
            "pos.x -= floor((pos.x-center.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x; \\\n"
            "pos.y -= floor((pos.y-center.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y; \\\n"
            "pos.z -= floor((pos.z-center.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;}";
    }

    // Create utilities objects.

    bonded = new OpenCLBondedUtilities(*this);
    nonbonded = new OpenCLNonbondedUtilities(*this);
    integration = new OpenCLIntegrationUtilities(*this, system);
    expression = new OpenCLExpressionUtilities(*this);
}

/**
 * Release everything this context owns: the registered force infos, reorder
 * listeners, pre/post computations, the pinned buffer object, and the
 * utilities objects created by the constructor.
 */
OpenCLContext::~OpenCLContext() {
    for (auto* forceInfo : forces)
        delete forceInfo;
    for (auto* reorderListener : reorderListeners)
        delete reorderListener;
    for (auto* preComputation : preComputations)
        delete preComputation;
    for (auto* postComputation : postComputations)
        delete postComputation;

    // Deleting a null pointer is a no-op, so no explicit checks are needed here.
    delete pinnedBuffer;
    delete integration;
    delete expression;
    delete bonded;
    delete nonbonded;
}

496
void OpenCLContext::initialize() {
Peter Eastman's avatar
Peter Eastman committed
497
    bonded->initialize(system);
498
    numForceBuffers = std::max(numForceBuffers, (int) platformData.contexts.size());
Peter Eastman's avatar
Peter Eastman committed
499
    numForceBuffers = std::max(numForceBuffers, bonded->getNumForceBuffers());
500
    numForceBuffers = std::max(numForceBuffers, nonbonded->getNumForceBuffers());
501
    int energyBufferSize = max(numThreadBlocks*ThreadBlockSize, nonbonded->getNumEnergyBuffers());
502
    if (useDoublePrecision) {
peastman's avatar
peastman committed
503
504
505
506
        forceBuffers.initialize<mm_double4>(*this, paddedNumAtoms*numForceBuffers, "forceBuffers");
        force.initialize<mm_double4>(*this, &forceBuffers.getDeviceBuffer(), paddedNumAtoms, "force");
        energyBuffer.initialize<cl_double>(*this, energyBufferSize, "energyBuffer");
        energySum.initialize<cl_double>(*this, 1, "energySum");
507
    }
Peter Eastman's avatar
Peter Eastman committed
508
    else if (useMixedPrecision) {
peastman's avatar
peastman committed
509
510
511
512
        forceBuffers.initialize<mm_float4>(*this, paddedNumAtoms*numForceBuffers, "forceBuffers");
        force.initialize<mm_float4>(*this, &forceBuffers.getDeviceBuffer(), paddedNumAtoms, "force");
        energyBuffer.initialize<cl_double>(*this, energyBufferSize, "energyBuffer");
        energySum.initialize<cl_double>(*this, 1, "energySum");
Peter Eastman's avatar
Peter Eastman committed
513
514
    }
    else {
peastman's avatar
peastman committed
515
516
517
518
        forceBuffers.initialize<mm_float4>(*this, paddedNumAtoms*numForceBuffers, "forceBuffers");
        force.initialize<mm_float4>(*this, &forceBuffers.getDeviceBuffer(), paddedNumAtoms, "force");
        energyBuffer.initialize<cl_float>(*this, energyBufferSize, "energyBuffer");
        energySum.initialize<cl_float>(*this, 1, "energySum");
519
    }
520
521
522
523
    reduceForcesKernel.setArg<cl::Buffer>(0, longForceBuffer.getDeviceBuffer());
    reduceForcesKernel.setArg<cl::Buffer>(1, forceBuffers.getDeviceBuffer());
    reduceForcesKernel.setArg<cl_int>(2, paddedNumAtoms);
    reduceForcesKernel.setArg<cl_int>(3, numForceBuffers);
524
    addAutoclearBuffer(longForceBuffer);
peastman's avatar
peastman committed
525
526
    addAutoclearBuffer(forceBuffers);
    addAutoclearBuffer(energyBuffer);
527
528
529
    int numEnergyParamDerivs = energyParamDerivNames.size();
    if (numEnergyParamDerivs > 0) {
        if (useDoublePrecision || useMixedPrecision)
peastman's avatar
peastman committed
530
            energyParamDerivBuffer.initialize<cl_double>(*this, numEnergyParamDerivs*energyBufferSize, "energyParamDerivBuffer");
531
        else
peastman's avatar
peastman committed
532
533
            energyParamDerivBuffer.initialize<cl_float>(*this, numEnergyParamDerivs*energyBufferSize, "energyParamDerivBuffer");
        addAutoclearBuffer(energyParamDerivBuffer);
534
    }
535
536
537
    int bufferBytes = max(max(velm.getSize()*velm.getElementSize(),
            energyBufferSize*energyBuffer.getElementSize()),
            longForceBuffer.getSize()*longForceBuffer.getElementSize());
538
    pinnedBuffer = new cl::Buffer(context, CL_MEM_ALLOC_HOST_PTR, bufferBytes);
539
    pinnedMemory = currentQueue.enqueueMapBuffer(*pinnedBuffer, CL_TRUE, CL_MAP_READ | CL_MAP_WRITE, 0, bufferBytes);
540
541
542
543
544
545
546
    for (int i = 0; i < numAtoms; i++) {
        double mass = system.getParticleMass(i);
        if (useDoublePrecision || useMixedPrecision)
            ((mm_double4*) pinnedMemory)[i] = mm_double4(0.0, 0.0, 0.0, mass == 0.0 ? 0.0 : 1.0/mass);
        else
            ((mm_float4*) pinnedMemory)[i] = mm_float4(0.0f, 0.0f, 0.0f, mass == 0.0 ? 0.0f : (cl_float) (1.0/mass));
    }
peastman's avatar
peastman committed
547
    velm.upload(pinnedMemory);
548
    findMoleculeGroups();
549
    nonbonded->initialize(system);
550
551
}

552
553
void OpenCLContext::initializeContexts() {
    getPlatformData().initializeContexts(system);
554
555
}

556
557
558
559
560
/**
 * Register a force with this context.
 *
 * The force is always passed on to ComputeContext::addForce().  If it is also an
 * OpenCLForceInfo, the number of force buffers it requires is requested as well.
 *
 * @param force    the force info object to register
 */
void OpenCLContext::addForce(ComputeForceInfo* force) {
    ComputeContext::addForce(force);
    if (OpenCLForceInfo* clinfo = dynamic_cast<OpenCLForceInfo*>(force))
        requestForceBuffers(clinfo->getRequiredForceBuffers());
}

563
564
/**
 * Request that at least a given number of force buffers be available.  The stored
 * count only ever grows: a request smaller than the current value has no effect.
 *
 * @param minBuffers    the minimum number of force buffers required by the caller
 */
void OpenCLContext::requestForceBuffers(int minBuffers) {
    if (minBuffers > numForceBuffers)
        numForceBuffers = minBuffers;
}

567
568
/**
 * Compile a program that needs no caller-specific preprocessor definitions.  This simply
 * forwards to the three argument overload with an empty set of defines.
 *
 * @param source             the source code of the program
 * @param optimizationFlags  compiler options, passed through unchanged
 * @return the compiled program
 */
cl::Program OpenCLContext::createProgram(const string source, const char* optimizationFlags) {
    const map<string, string> noDefines;
    return createProgram(source, noDefines, optimizationFlags);
}

571
/**
 * Assemble the complete source for a program and build it for this context's device.
 *
 * The generated source contains, in this order:
 *   - a comment recording the compiler options (only when they are non-empty),
 *   - one #define for each entry of compilationDefines whose name does not also appear in
 *     `defines`,
 *   - the cl_khr_fp64 extension pragma (only when supportsDoublePrecision is set),
 *   - typedefs for real, real2, real3, real4 and mixed, mixed2, mixed3, mixed4,
 *   - OpenCLKernelSources::common,
 *   - one #define for each entry of `defines`,
 *   - the caller's source.
 *
 * @param source             the source code of the program
 * @param defines            macro names mapped to their values; an empty value produces a
 *                           #define with no replacement text
 * @param optimizationFlags  compiler options, or NULL to use defaultOptimizationOptions
 * @return the built program
 * @throws OpenMMException containing the build log if building the program raises a cl::Error
 */
cl::Program OpenCLContext::createProgram(const string source, const map<string, string>& defines, const char* optimizationFlags) {
    string options = (optimizationFlags == NULL ? defaultOptimizationOptions : string(optimizationFlags));
    stringstream src;
    if (!options.empty())
        src << "// Compilation Options: " << options << endl << endl;
    for (auto& pair : compilationDefines) {
        // Query defines to avoid duplicate variables
        if (defines.find(pair.first) == defines.end()) {
            src << "#define " << pair.first;
            if (!pair.second.empty())
                src << " " << pair.second;
            src << endl;
        }
    }
    if (!compilationDefines.empty())
        src << endl;
    if (supportsDoublePrecision)
        src << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n";

    // "real" follows the double precision setting, while "mixed" is also double in mixed
    // precision mode.  Each gets a scalar typedef plus 2, 3 and 4 component vector typedefs.

    const char* realType = (useDoublePrecision ? "double" : "float");
    const char* mixedType = (useDoublePrecision || useMixedPrecision ? "double" : "float");
    for (const char* suffix : {"", "2", "3", "4"})
        src << "typedef " << realType << suffix << " real" << suffix << ";\n";
    for (const char* suffix : {"", "2", "3", "4"})
        src << "typedef " << mixedType << suffix << " mixed" << suffix << ";\n";
    src << OpenCLKernelSources::common << endl;
    for (auto& pair : defines) {
        src << "#define " << pair.first;
        if (!pair.second.empty())
            src << " " << pair.second;
        src << endl;
    }
    if (!defines.empty())
        src << endl;
    src << source << endl;
    cl::Program::Sources sources({src.str()});
    cl::Program program(context, sources);
    try {
        program.build(vector<cl::Device>(1, device), options.c_str());
    } catch (const cl::Error&) {
        // Catch by const reference so the exception object is neither copied nor sliced.
        // The build log is reported instead of the error itself.
        throw OpenMMException("Error compiling kernel: "+program.getBuildInfo<CL_PROGRAM_BUILD_LOG>(device));
    }
    return program;
}

633
634
635
636
637
638
639
640
641
642
643
644
/**
 * Get the command queue this context is currently using.
 *
 * @return a reference to the current queue
 */
cl::CommandQueue& OpenCLContext::getQueue() {
    return currentQueue;
}

/**
 * Replace the command queue this context is currently using.
 *
 * @param queue    the queue that getQueue() should return from now on
 */
void OpenCLContext::setQueue(cl::CommandQueue& queue) {
    currentQueue = queue;
}

/**
 * Make the default queue the current queue again, undoing any earlier setQueue() call.
 */
void OpenCLContext::restoreDefaultQueue() {
    currentQueue = defaultQueue;
}

645
646
647
648
649
650
/**
 * Create a new, default constructed OpenCLArray on the heap.
 *
 * @return a pointer to the new array
 */
OpenCLArray* OpenCLContext::createArray() {
    OpenCLArray* array = new OpenCLArray();
    return array;
}

/**
 * Create a new OpenCLEvent for this context.
 *
 * @return the event, held by a shared pointer
 */
ComputeEvent OpenCLContext::createEvent() {
    shared_ptr<ComputeEventImpl> event(new OpenCLEvent(*this));
    return event;
}

653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
/**
 * Compile source code into a program and wrap it in an OpenCLProgram.
 *
 * @param source     the source code of the program
 * @param defines    macro names mapped to their values, forwarded to createProgram()
 * @return the wrapped program, held by a shared pointer
 */
ComputeProgram OpenCLContext::compileProgram(const std::string source, const std::map<std::string, std::string>& defines) {
    cl::Program compiled = createProgram(source, defines);
    shared_ptr<ComputeProgramImpl> wrapped(new OpenCLProgram(*this, compiled));
    return wrapped;
}

/**
 * Convert a generic array reference to the OpenCLArray behind it.
 *
 * If the argument is a ComputeArray, the array it wraps is examined; otherwise the
 * argument itself is.
 *
 * @param array    the array to convert
 * @return the underlying OpenCLArray
 * @throws OpenMMException if the examined object is not an OpenCLArray
 */
OpenCLArray& OpenCLContext::unwrap(ArrayInterface& array) const {
    ArrayInterface* target = &array;
    if (ComputeArray* wrapper = dynamic_cast<ComputeArray*>(target))
        target = &wrapper->getArray();
    OpenCLArray* clarray = dynamic_cast<OpenCLArray*>(target);
    if (clarray == NULL)
        throw OpenMMException("Array argument is not an OpenCLArray");
    return *clarray;
}

670
671
672
673
/**
 * Enqueue a kernel on the current queue.
 *
 * The global work size is the number of blocks needed to cover workUnits, capped at
 * numThreadBlocks, multiplied by the block size.
 *
 * @param kernel       the kernel to execute
 * @param workUnits    the number of work units the kernel should cover
 * @param blockSize    the work group size, or -1 to use ThreadBlockSize
 * @throws OpenMMException naming the kernel and the OpenCL error if enqueueing fails
 */
void OpenCLContext::executeKernel(cl::Kernel& kernel, int workUnits, int blockSize) {
    if (blockSize == -1)
        blockSize = ThreadBlockSize;
    int size = std::min((workUnits+blockSize-1)/blockSize, numThreadBlocks)*blockSize;
    try {
        currentQueue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange(size), cl::NDRange(blockSize));
    }
    catch (const cl::Error& err) {
        // Catch by const reference so the exception object is neither copied nor sliced.
        stringstream str;
        str<<"Error invoking kernel "<<kernel.getInfo<CL_KERNEL_FUNCTION_NAME>()<<": "<<err.what()<<" ("<<err.err()<<")";
        throw OpenMMException(str.str());
    }
}

684
685
/**
 * Clear the whole device buffer of an array.
 *
 * @param array    the array to clear; it must unwrap to an OpenCLArray
 */
void OpenCLContext::clearBuffer(ArrayInterface& array) {
    OpenCLArray& clarray = unwrap(array);
    int sizeInBytes = array.getSize()*array.getElementSize();
    clearBuffer(clarray.getDeviceBuffer(), sizeInBytes);
}

688
/**
 * Clear a region of device memory by running clearBufferKernel over it.
 *
 * @param memory    the memory object to clear
 * @param size      the size of the region in bytes; the kernel is given size/4 as its
 *                  element count
 */
void OpenCLContext::clearBuffer(cl::Memory& memory, int size) {
    const int numWords = size/4;
    clearBufferKernel.setArg<cl::Memory>(0, memory);
    clearBufferKernel.setArg<cl_int>(1, numWords);
    executeKernel(clearBufferKernel, numWords, 128);
}

695
696
/**
 * Register the whole device buffer of an array to be cleared by clearAutoclearBuffers().
 *
 * @param array    the array to register; it must unwrap to an OpenCLArray
 */
void OpenCLContext::addAutoclearBuffer(ArrayInterface& array) {
    OpenCLArray& clarray = unwrap(array);
    int sizeInBytes = array.getSize()*array.getElementSize();
    addAutoclearBuffer(clarray.getDeviceBuffer(), sizeInBytes);
}

699
700
/**
 * Register a memory object to be cleared by clearAutoclearBuffers().
 *
 * Only the address of the memory object is stored, so the object must remain valid for as
 * long as it is registered.
 *
 * @param memory    the memory object to register
 * @param size      the size of the region to clear in bytes; it is recorded as size/4
 */
void OpenCLContext::addAutoclearBuffer(cl::Memory& memory, int size) {
    const int numWords = size/4;
    autoclearBuffers.push_back(&memory);
    autoclearBufferSizes.push_back(numWords);
}

void OpenCLContext::clearAutoclearBuffers() {
    int base = 0;
    int total = autoclearBufferSizes.size();
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
    while (total-base >= 6) {
        clearSixBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearSixBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearSixBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearSixBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearSixBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearSixBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearSixBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearSixBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
        clearSixBuffersKernel.setArg<cl::Memory>(8, *autoclearBuffers[base+4]);
        clearSixBuffersKernel.setArg<cl_int>(9, autoclearBufferSizes[base+4]);
        clearSixBuffersKernel.setArg<cl::Memory>(10, *autoclearBuffers[base+5]);
        clearSixBuffersKernel.setArg<cl_int>(11, autoclearBufferSizes[base+5]);
        executeKernel(clearSixBuffersKernel, max(max(max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), autoclearBufferSizes[base+4]), autoclearBufferSizes[base+5]), 128);
        base += 6;
    }
    if (total-base == 5) {
        clearFiveBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearFiveBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearFiveBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearFiveBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearFiveBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearFiveBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearFiveBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearFiveBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
        clearFiveBuffersKernel.setArg<cl::Memory>(8, *autoclearBuffers[base+4]);
        clearFiveBuffersKernel.setArg<cl_int>(9, autoclearBufferSizes[base+4]);
        executeKernel(clearFiveBuffersKernel, max(max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), autoclearBufferSizes[base+4]), 128);
    }
    else if (total-base == 4) {
737
738
739
740
741
742
743
744
        clearFourBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearFourBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearFourBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearFourBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearFourBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearFourBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearFourBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearFourBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
745
        executeKernel(clearFourBuffersKernel, max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), 128);
746
    }
747
    else if (total-base == 3) {
748
749
750
751
752
753
        clearThreeBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearThreeBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearThreeBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearThreeBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearThreeBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearThreeBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
754
        executeKernel(clearThreeBuffersKernel, max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), 128);
755
756
757
758
759
760
    }
    else if (total-base == 2) {
        clearTwoBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearTwoBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearTwoBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearTwoBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
761
        executeKernel(clearTwoBuffersKernel, max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), 128);
762
763
    }
    else if (total-base == 1) {
764
        clearBuffer(*autoclearBuffers[base], autoclearBufferSizes[base]*4);
765
766
767
    }
}

768
void OpenCLContext::reduceForces() {
769
    executeKernel(reduceForcesKernel, paddedNumAtoms, 128);
770
771
}

772
/**
 * Run reduceReal4Kernel over an array that is treated as numBuffers equally sized buffers
 * stored one after another.
 *
 * @param array         the array to reduce
 * @param numBuffers    the number of buffers the array is divided into
 */
void OpenCLContext::reduceBuffer(OpenCLArray& array, int numBuffers) {
    // Number of elements in each of the individual buffers.
    int elementsPerBuffer = array.getSize()/numBuffers;
    reduceReal4Kernel.setArg<cl::Buffer>(0, array.getDeviceBuffer());
    reduceReal4Kernel.setArg<cl_int>(1, elementsPerBuffer);
    reduceReal4Kernel.setArg<cl_int>(2, numBuffers);
    executeKernel(reduceReal4Kernel, elementsPerBuffer, 128);
}
779

Peter Eastman's avatar
Peter Eastman committed
780
781
782
783
double OpenCLContext::reduceEnergy() {
    int workGroupSize  = device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>();
    if (workGroupSize > 512)
        workGroupSize = 512;
peastman's avatar
peastman committed
784
785
786
    reduceEnergyKernel.setArg<cl::Buffer>(0, energyBuffer.getDeviceBuffer());
    reduceEnergyKernel.setArg<cl::Buffer>(1, energySum.getDeviceBuffer());
    reduceEnergyKernel.setArg<cl_int>(2, energyBuffer.getSize());
Peter Eastman's avatar
Peter Eastman committed
787
    reduceEnergyKernel.setArg<cl_int>(3, workGroupSize);
peastman's avatar
peastman committed
788
    reduceEnergyKernel.setArg(4, workGroupSize*energyBuffer.getElementSize(), NULL);
Peter Eastman's avatar
Peter Eastman committed
789
790
791
    executeKernel(reduceEnergyKernel, workGroupSize, workGroupSize);
    if (getUseDoublePrecision() || getUseMixedPrecision()) {
        double energy;
peastman's avatar
peastman committed
792
        energySum.download(&energy);
Peter Eastman's avatar
Peter Eastman committed
793
794
795
796
        return energy;
    }
    else {
        float energy;
peastman's avatar
peastman committed
797
        energySum.download(&energy);
Peter Eastman's avatar
Peter Eastman committed
798
799
800
801
        return energy;
    }
}

802
void OpenCLContext::setCharges(const vector<double>& charges) {
peastman's avatar
peastman committed
803
804
    if (!chargeBuffer.isInitialized())
        chargeBuffer.initialize(*this, numAtoms, useDoublePrecision ? sizeof(double) : sizeof(float), "chargeBuffer");
peastman's avatar
peastman committed
805
806
807
    vector<double> c(numAtoms);
    for (int i = 0; i < numAtoms; i++)
        c[i] = charges[i];
808
    chargeBuffer.upload(c, true);
peastman's avatar
peastman committed
809
810
811
    setChargesKernel.setArg<cl::Buffer>(0, chargeBuffer.getDeviceBuffer());
    setChargesKernel.setArg<cl::Buffer>(1, posq.getDeviceBuffer());
    setChargesKernel.setArg<cl::Buffer>(2, atomIndexDevice.getDeviceBuffer());
812
813
814
815
    setChargesKernel.setArg<cl_int>(3, numAtoms);
    executeKernel(setChargesKernel, numAtoms);
}

816
817
818
819
820
821
/**
 * Request permission to use the posq charges.  Only the first call succeeds.
 *
 * @return true on the first call, false on every later call
 */
bool OpenCLContext::requestPosqCharges() {
    if (hasAssignedPosqCharges)
        return false;
    hasAssignedPosqCharges = true;
    return true;
}

822
823
824
825
826
827
828
829
830
/**
 * Register a parameter name in energyParamDerivNames.  Registering a name that is already
 * present has no effect.
 *
 * @param param    the name of the parameter
 */
void OpenCLContext::addEnergyParameterDerivative(const string& param) {
    // Nothing to do if this parameter is already in the list.

    for (auto& registered : energyParamDerivNames)
        if (registered == param)
            return;
    energyParamDerivNames.push_back(param);
}

831
832
void OpenCLContext::flushQueue() {
    getQueue().flush();
833
}