OpenCLContext.cpp 40.8 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2020 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

27
28
29
30
#ifdef WIN32
  #define _USE_MATH_DEFINES // Needed to get M_PI
#endif
#include <cmath>
31
32
#include "OpenCLContext.h"
#include "OpenCLArray.h"
Peter Eastman's avatar
Peter Eastman committed
33
#include "OpenCLBondedUtilities.h"
34
#include "OpenCLEvent.h"
35
#include "OpenCLForceInfo.h"
36
#include "OpenCLIntegrationUtilities.h"
37
#include "OpenCLKernelSources.h"
38
#include "OpenCLNonbondedUtilities.h"
39
40
#include "OpenCLProgram.h"
#include "openmm/common/ComputeArray.h"
41
#include "openmm/Platform.h"
42
#include "openmm/System.h"
43
#include "openmm/VirtualSite.h"
44
#include "openmm/internal/ContextImpl.h"
Peter Eastman's avatar
Peter Eastman committed
45
#include <algorithm>
46
47
#include <fstream>
#include <iostream>
48
#include <set>
49
#include <sstream>
50
#include <typeinfo>
51
52

using namespace OpenMM;
53
using namespace std;
54

55
56
57
#ifndef CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV
  #define CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV 0x4000
#endif
58
59
60
#ifndef CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV
  #define CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV 0x4001
#endif
61

62
63
64
// Default thread-block (work-group) size.  The constructor exports it to kernel source as
// WORK_GROUP_SIZE, and initialize() uses it (times numThreadBlocks) to size the energy buffer.
const int OpenCLContext::ThreadBlockSize = 64;
// Tile size in atoms: numAtoms is padded up to a multiple of this (paddedNumAtoms), and
// numAtomBlocks is the number of tiles covering the padded atoms.
const int OpenCLContext::TileSize = 32;

65
/**
 * Error notification callback passed to cl::Context when this file creates a new context.
 * Prints the runtime's message to stderr, except for compiler build-log notifications,
 * which are silently ignored.
 */
static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_info, size_t cb, void* user_data) {
    // OS X Lion insists on calling this for every build warning, even though they aren't errors.
    const string buildWarningPrefix = "OpenCL Build Warning : Compiler build log:";
    const string message(errinfo);
    const bool isBuildWarning = (message.compare(0, buildWarningPrefix.length(), buildWarningPrefix) == 0);
    if (!isBuildWarning)
        std::cerr << "OpenCL internal error: " << errinfo << std::endl;
}

72
73
74
75
76
77
78
79
/**
 * Report whether an OpenCL platform comes from a vendor whose implementation is known to work.
 * A platform counts as supported when its CL_PLATFORM_VENDOR string starts with one of the
 * recognized vendor names.
 */
static bool isSupported(cl::Platform platform) {
    static const char* const recognizedVendors[] = {"NVIDIA", "Advanced Micro Devices", "Apple", "Intel"};
    const string vendorName = platform.getInfo<CL_PLATFORM_VENDOR>();
    for (const char* prefix : recognizedVendors) {
        if (vendorName.find(prefix) == 0)
            return true;
    }
    return false;
}

80
/**
 * Create an OpenCL context for a System.
 *
 * Selects an OpenCL platform and device (the fastest compatible device, preferring platforms
 * whose vendor is recognized by isSupported()), creates or shares the cl::Context and command
 * queue, allocates the per-atom arrays, builds the shared utility kernels, and fills in the
 * preprocessor defines used when compiling kernel source.
 *
 * @param system           the System to simulate
 * @param platformIndex    index of the OpenCL platform to use, or -1 to consider all platforms
 * @param deviceIndex      index of the device within the platform, or -1 to consider all devices
 * @param precision        "single", "mixed", or "double"
 * @param platformData     platform data; the current size of platformData.contexts becomes this context's index
 * @param originalContext  if not NULL, its cl::Context and default queue are reused instead of creating new ones
 * @throws OpenMMException on an illegal precision or index, when no compatible platform/device
 *                         is found, or when an OpenCL call fails during device setup
 */
OpenCLContext::OpenCLContext(const System& system, int platformIndex, int deviceIndex, const string& precision, OpenCLPlatform::PlatformData& platformData, OpenCLContext* originalContext) :
        ComputeContext(system), platformData(platformData), numForceBuffers(0), hasAssignedPosqCharges(false),
        integration(NULL), expression(NULL), bonded(NULL), nonbonded(NULL), pinnedBuffer(NULL) {
    // pinnedBuffer is only allocated by initialize(); start it at NULL so the destructor's
    // delete is safe even if initialize() never runs.
    if (precision == "single") {
        useDoublePrecision = false;
        useMixedPrecision = false;
    }
    else if (precision == "mixed") {
        useDoublePrecision = false;
        useMixedPrecision = true;
    }
    else if (precision == "double") {
        useDoublePrecision = true;
        useMixedPrecision = false;
    }
    else
        throw OpenMMException("Illegal value for Precision: "+precision);
    try {
        contextIndex = platformData.contexts.size();
        std::vector<cl::Platform> platforms;
        cl::Platform::get(&platforms);
        if (platformIndex < -1 || platformIndex >= (int) platforms.size())
            throw OpenMMException("Illegal value for OpenCLPlatformIndex: "+intToString(platformIndex));
        if (platforms.size() > 1 && platformIndex == -1 && deviceIndex != -1)
            throw OpenMMException("Specified DeviceIndex but not OpenCLPlatformIndex.  When multiple platforms are available, a platform index is needed to specify a device.");
        const int minThreadBlockSize = 32;

        // Scan the candidate platforms/devices and remember the fastest compatible one.
        int bestSpeed = -1;
        int bestDevice = -1;
        int bestPlatform = -1;
        bool bestSupported = false;
        for (int j = 0; j < (int) platforms.size(); j++) {
            // If they supplied a valid platformIndex, we only look through that platform
            if (j != platformIndex && platformIndex != -1)
                continue;

            // Always prefer a supported platform over an unsupported one.
            bool supported = isSupported(platforms[j]);
            if (!supported && bestSupported)
                continue;
            string platformVendor = platforms[j].getInfo<CL_PLATFORM_VENDOR>();
            vector<cl::Device> devices;
            try {
                platforms[j].getDevices(CL_DEVICE_TYPE_ALL, &devices);
            }
            catch (...) {
                // There are no devices available for this platform.
                continue;
            }
            if (deviceIndex < -1 || deviceIndex >= (int) devices.size())
                throw OpenMMException("Illegal value for DeviceIndex: "+intToString(deviceIndex));

            for (int i = 0; i < (int) devices.size(); i++) {
                // If they supplied a valid deviceIndex, we only look through that one
                if (i != deviceIndex && deviceIndex != -1)
                    continue;
                if (platformVendor == "Apple" && (devices[i].getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU))
                    continue; // The CPU device on OS X won't work correctly.
                if (useMixedPrecision || useDoublePrecision) {
                    bool supportsDouble = (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_fp64") != string::npos);
                    if (!supportsDouble)
                        continue; // This device does not support double precision.
                }
                int maxSize = devices[i].getInfo<CL_DEVICE_MAX_WORK_ITEM_SIZES>()[0];
                int processingElementsPerComputeUnit = 8;
                if (devices[i].getInfo<CL_DEVICE_TYPE>() != CL_DEVICE_TYPE_GPU) {
                    processingElementsPerComputeUnit = 1;
                }
                else if (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_nv_device_attribute_query") != string::npos) {
                    cl_uint computeCapabilityMajor;
                    clGetDeviceInfo(devices[i](), CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV, sizeof(cl_uint), &computeCapabilityMajor, NULL);
                    processingElementsPerComputeUnit = (computeCapabilityMajor < 2 ? 8 : 32);
                }
                else if (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_amd_device_attribute_query") != string::npos) {
                    // This attribute does not ensure that all queries are supported by the runtime (it may be an older runtime,
                    // or the CPU device) so still have to check for errors.
                    try {
                        processingElementsPerComputeUnit =
                            // AMD GPUs either have a single VLIW SIMD or multiple scalar SIMDs.
                            // The SIMD width is the number of threads the SIMD executes per cycle.
                            // This will be less than the wavefront width since it takes several
                            // cycles to execute the full wavefront.
                            // The SIMD instruction width is the VLIW instruction width (or 1 for scalar),
                            // this is the number of ALUs that can be executing per instruction per thread.
                            devices[i].getInfo<CL_DEVICE_SIMD_PER_COMPUTE_UNIT_AMD>() *
                            devices[i].getInfo<CL_DEVICE_SIMD_WIDTH_AMD>() *
                            devices[i].getInfo<CL_DEVICE_SIMD_INSTRUCTION_WIDTH_AMD>();
                        // Just in case any of the queries return 0.
                        if (processingElementsPerComputeUnit <= 0)
                            processingElementsPerComputeUnit = 1;
                    }
                    catch (const cl::Error&) {
                        // Runtime does not support the queries so use default.
                    }
                }
                // Relative speed estimate: compute units x processing elements per unit x clock frequency.
                int speed = devices[i].getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>()*processingElementsPerComputeUnit*devices[i].getInfo<CL_DEVICE_MAX_CLOCK_FREQUENCY>();
                if (maxSize >= minThreadBlockSize && (speed > bestSpeed || (supported && !bestSupported))) {
                    bestDevice = i;
                    bestSpeed = speed;
                    bestPlatform = j;
                    bestSupported = supported;
                }
            }
        }

        if (bestPlatform == -1)
            throw OpenMMException("No compatible OpenCL platform is available");

        if (bestDevice == -1)
            throw OpenMMException("No compatible OpenCL device is available");

        if (!bestSupported)
            cout << "WARNING: Using an unsupported OpenCL implementation.  Results may be incorrect." << endl;

        vector<cl::Device> devices;
        platforms[bestPlatform].getDevices(CL_DEVICE_TYPE_ALL, &devices);
        string platformVendor = platforms[bestPlatform].getInfo<CL_PLATFORM_VENDOR>();
        device = devices[bestDevice];

        this->deviceIndex = bestDevice;
        this->platformIndex = bestPlatform;
        if (device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>() < minThreadBlockSize)
            throw OpenMMException("The specified OpenCL device is not compatible with OpenMM");
        compilationDefines["WORK_GROUP_SIZE"] = intToString(ThreadBlockSize);
        if (platformVendor.size() >= 5 && platformVendor.substr(0, 5) == "Intel")
            defaultOptimizationOptions = "";
        else
            defaultOptimizationOptions = "-cl-mad-enable -cl-no-signed-zeros";
        supports64BitGlobalAtomics = (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_int64_base_atomics") != string::npos);
        supportsDoublePrecision = (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_fp64") != string::npos);
        if ((useDoublePrecision || useMixedPrecision) && !supportsDoublePrecision)
            throw OpenMMException("This device does not support double precision");

        // Vendor-specific tuning: SIMD width, thread blocks per compute unit, atomics support.
        string vendor = device.getInfo<CL_DEVICE_VENDOR>();
        int numThreadBlocksPerComputeUnit = 6;
        if (vendor.size() >= 6 && vendor.substr(0, 6) == "NVIDIA") {
            compilationDefines["WARPS_ARE_ATOMIC"] = "";
            simdWidth = 32;
            if (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_nv_device_attribute_query") != string::npos) {
                // Compute level 1.2 and later Nvidia GPUs support 64 bit atomics, even though they don't list the
                // proper extension as supported.  We only use them on compute level 2.0 or later, since they're very
                // slow on earlier GPUs.
                cl_uint computeCapabilityMajor;
                clGetDeviceInfo(device(), CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV, sizeof(cl_uint), &computeCapabilityMajor, NULL);
                if (computeCapabilityMajor > 1)
                    supports64BitGlobalAtomics = true;
                if (computeCapabilityMajor == 5) {
                    // Workaround for a bug in Maxwell on CUDA 6.x.
                    string platformVersion = platforms[bestPlatform].getInfo<CL_PLATFORM_VERSION>();
                    if (platformVersion.find("CUDA 6") != string::npos)
                        supports64BitGlobalAtomics = false;
                }
            }
        }
        else if (vendor.size() >= 28 && vendor.substr(0, 28) == "Advanced Micro Devices, Inc.") {
            if (device.getInfo<CL_DEVICE_TYPE>() != CL_DEVICE_TYPE_GPU) {
                /// \todo Is 6 a good value for the OpenCL CPU device?
                // numThreadBlocksPerComputeUnit = ?;
                simdWidth = 1;
            }
            else {
                bool amdPostSdk2_4 = false;
                // Default to 1 which will use the default kernels.
                simdWidth = 1;
                if (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_amd_device_attribute_query") != string::npos) {
                    // This attribute does not ensure that all queries are supported by the runtime so still have to
                    // check for errors.
                    try {
                        // Must catch cl:Error as will fail if runtime does not support queries.
                        cl_uint simdPerComputeUnit = device.getInfo<CL_DEVICE_SIMD_PER_COMPUTE_UNIT_AMD>();
                        simdWidth = device.getInfo<CL_DEVICE_WAVEFRONT_WIDTH_AMD>();

                        // If the GPU has multiple SIMDs per compute unit then it is uses the scalar instruction
                        // set instead of the VLIW instruction set. It therefore needs more thread blocks per
                        // compute unit to hide memory latency.
                        if (simdPerComputeUnit > 1) {
                            if (simdWidth == 32)
                                numThreadBlocksPerComputeUnit = 6*simdPerComputeUnit; // Navi seems to like more thread blocks than older GPUs
                            else
                                numThreadBlocksPerComputeUnit = 4*simdPerComputeUnit;
                        }

                        // If the queries are supported then must be newer than SDK 2.4.
                        amdPostSdk2_4 = true;
                    }
                    catch (const cl::Error&) {
                        // Runtime does not support the query so is unlikely to be the newer scalar GPU.
                        // Stay with the default simdWidth and numThreadBlocksPerComputeUnit.
                    }
                }
                // AMD APP SDK 2.4 has a performance problem with atomics. Enable the work around. This is fixed after SDK 2.4.
                if (!amdPostSdk2_4)
                    compilationDefines["AMD_ATOMIC_WORK_AROUND"] = "";
            }
        }
        else
            simdWidth = 1;
        if (supports64BitGlobalAtomics)
            compilationDefines["SUPPORTS_64_BIT_ATOMICS"] = "";
        if (supportsDoublePrecision)
            compilationDefines["SUPPORTS_DOUBLE_PRECISION"] = "";
        if (simdWidth >= 32)
            compilationDefines["SYNC_WARPS"] = "mem_fence(CLK_LOCAL_MEM_FENCE)";
        else
            compilationDefines["SYNC_WARPS"] = "barrier(CLK_LOCAL_MEM_FENCE)";

        // Create the cl::Context and queue, or share them with the original context.
        vector<cl::Device> contextDevices;
        contextDevices.push_back(device);
        cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[bestPlatform](), 0};
        if (originalContext == NULL) {
            context = cl::Context(contextDevices, cprops, errorCallback);
            defaultQueue = cl::CommandQueue(context, device);
        }
        else {
            context = originalContext->context;
            defaultQueue = originalContext->defaultQueue;
        }
        currentQueue = defaultQueue;

        // Allocate per-atom arrays, padded up to a whole number of tiles.
        numAtoms = system.getNumParticles();
        paddedNumAtoms = TileSize*((numAtoms+TileSize-1)/TileSize);
        numAtomBlocks = (paddedNumAtoms+(TileSize-1))/TileSize;
        numThreadBlocks = numThreadBlocksPerComputeUnit*device.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
        if (useDoublePrecision) {
            posq.initialize<mm_double4>(*this, paddedNumAtoms, "posq");
            velm.initialize<mm_double4>(*this, paddedNumAtoms, "velm");
            compilationDefines["USE_DOUBLE_PRECISION"] = "1";
            compilationDefines["convert_real4"] = "convert_double4";
            compilationDefines["make_real2"] = "make_double2";
            compilationDefines["make_real3"] = "make_double3";
            compilationDefines["make_real4"] = "make_double4";
            compilationDefines["convert_mixed4"] = "convert_double4";
            compilationDefines["make_mixed2"] = "make_double2";
            compilationDefines["make_mixed3"] = "make_double3";
            compilationDefines["make_mixed4"] = "make_double4";
        }
        else if (useMixedPrecision) {
            posq.initialize<mm_float4>(*this, paddedNumAtoms, "posq");
            posqCorrection.initialize<mm_float4>(*this, paddedNumAtoms, "posqCorrection");
            velm.initialize<mm_double4>(*this, paddedNumAtoms, "velm");
            compilationDefines["USE_MIXED_PRECISION"] = "1";
            compilationDefines["convert_real4"] = "convert_float4";
            compilationDefines["make_real2"] = "make_float2";
            compilationDefines["make_real3"] = "make_float3";
            compilationDefines["make_real4"] = "make_float4";
            compilationDefines["convert_mixed4"] = "convert_double4";
            compilationDefines["make_mixed2"] = "make_double2";
            compilationDefines["make_mixed3"] = "make_double3";
            compilationDefines["make_mixed4"] = "make_double4";
        }
        else {
            posq.initialize<mm_float4>(*this, paddedNumAtoms, "posq");
            velm.initialize<mm_float4>(*this, paddedNumAtoms, "velm");
            compilationDefines["convert_real4"] = "convert_float4";
            compilationDefines["make_real2"] = "make_float2";
            compilationDefines["make_real3"] = "make_float3";
            compilationDefines["make_real4"] = "make_float4";
            compilationDefines["convert_mixed4"] = "convert_float4";
            compilationDefines["make_mixed2"] = "make_float2";
            compilationDefines["make_mixed3"] = "make_float3";
            compilationDefines["make_mixed4"] = "make_float4";
        }
        longForceBuffer.initialize<cl_long>(*this, 3*paddedNumAtoms, "longForceBuffer");
        posCellOffsets.resize(paddedNumAtoms, mm_int4(0, 0, 0, 0));
        atomIndexDevice.initialize<cl_int>(*this, paddedNumAtoms, "atomIndexDevice");
        atomIndex.resize(paddedNumAtoms);
        for (int i = 0; i < paddedNumAtoms; ++i)
            atomIndex[i] = i;
        atomIndexDevice.upload(atomIndex);
    }
    catch (const cl::Error& err) {
        std::stringstream str;
        str<<"Error initializing context: "<<err.what()<<" ("<<err.err()<<")";
        throw OpenMMException(str.str());
    }

    // Create utility kernels that are used in multiple places.

    cl::Program utilities = createProgram(OpenCLKernelSources::utilities);
    clearBufferKernel = cl::Kernel(utilities, "clearBuffer");
    clearTwoBuffersKernel = cl::Kernel(utilities, "clearTwoBuffers");
    clearThreeBuffersKernel = cl::Kernel(utilities, "clearThreeBuffers");
    clearFourBuffersKernel = cl::Kernel(utilities, "clearFourBuffers");
    clearFiveBuffersKernel = cl::Kernel(utilities, "clearFiveBuffers");
    clearSixBuffersKernel = cl::Kernel(utilities, "clearSixBuffers");
    reduceReal4Kernel = cl::Kernel(utilities, "reduceReal4Buffer");
    reduceForcesKernel = cl::Kernel(utilities, "reduceForces");
    reduceEnergyKernel = cl::Kernel(utilities, "reduceEnergy");
    setChargesKernel = cl::Kernel(utilities, "setCharges");

    // Decide whether native_sqrt(), native_rsqrt(), and native_recip() are sufficiently accurate to use.

    if (!useDoublePrecision) {
        // Evaluate the native functions on the device over a geometric range of inputs and compare
        // against host double-precision results; use a native version only if its max relative error < 1e-6.
        cl::Kernel accuracyKernel(utilities, "determineNativeAccuracy");
        OpenCLArray valuesArray(*this, 20, sizeof(mm_float8), "values");
        vector<mm_float8> values(valuesArray.getSize());
        float nextValue = 1e-4f;
        for (auto& val : values) {
            val.s0 = nextValue;
            nextValue *= (float) M_PI;
        }
        valuesArray.upload(values);
        accuracyKernel.setArg<cl::Buffer>(0, valuesArray.getDeviceBuffer());
        accuracyKernel.setArg<cl_int>(1, values.size());
        executeKernel(accuracyKernel, values.size());
        valuesArray.download(values);
        double maxSqrtError = 0.0, maxRsqrtError = 0.0, maxRecipError = 0.0, maxExpError = 0.0, maxLogError = 0.0;
        for (auto& val : values) {
            double v = val.s0;
            double correctSqrt = sqrt(v);
            maxSqrtError = max(maxSqrtError, fabs(correctSqrt-val.s1)/correctSqrt);
            maxRsqrtError = max(maxRsqrtError, fabs(1.0/correctSqrt-val.s2)*correctSqrt);
            maxRecipError = max(maxRecipError, fabs(1.0/v-val.s3)/val.s3);
            maxExpError = max(maxExpError, fabs(exp(v)-val.s4)/val.s4);
            maxLogError = max(maxLogError, fabs(log(v)-val.s5)/val.s5);
        }
        compilationDefines["SQRT"] = (maxSqrtError < 1e-6) ? "native_sqrt" : "sqrt";
        compilationDefines["RSQRT"] = (maxRsqrtError < 1e-6) ? "native_rsqrt" : "rsqrt";
        compilationDefines["RECIP"] = (maxRecipError < 1e-6) ? "native_recip" : "1.0f/";
        compilationDefines["EXP"] = (maxExpError < 1e-6) ? "native_exp" : "exp";
        compilationDefines["LOG"] = (maxLogError < 1e-6) ? "native_log" : "log";
    }
    else {
        compilationDefines["SQRT"] = "sqrt";
        compilationDefines["RSQRT"] = "rsqrt";
        compilationDefines["RECIP"] = "1.0/";
        compilationDefines["EXP"] = "exp";
        compilationDefines["LOG"] = "log";
    }
    compilationDefines["POW"] = "pow";
    compilationDefines["COS"] = "cos";
    compilationDefines["SIN"] = "sin";
    compilationDefines["TAN"] = "tan";
    compilationDefines["ACOS"] = "acos";
    compilationDefines["ASIN"] = "asin";
    compilationDefines["ATAN"] = "atan";

    // Set defines for applying periodic boundary conditions.

    Vec3 boxVectors[3];
    system.getDefaultPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
    boxIsTriclinic = (boxVectors[0][1] != 0.0 || boxVectors[0][2] != 0.0 ||
                      boxVectors[1][0] != 0.0 || boxVectors[1][2] != 0.0 ||
                      boxVectors[2][0] != 0.0 || boxVectors[2][1] != 0.0);
    if (boxIsTriclinic) {
        compilationDefines["APPLY_PERIODIC_TO_DELTA(delta)"] =
            "{"
            "real scale3 = floor(delta.z*invPeriodicBoxSize.z+0.5f); \\\n"
            "delta.xyz -= scale3*periodicBoxVecZ.xyz; \\\n"
            "real scale2 = floor(delta.y*invPeriodicBoxSize.y+0.5f); \\\n"
            "delta.xy -= scale2*periodicBoxVecY.xy; \\\n"
            "real scale1 = floor(delta.x*invPeriodicBoxSize.x+0.5f); \\\n"
            "delta.x -= scale1*periodicBoxVecX.x;}";
        compilationDefines["APPLY_PERIODIC_TO_POS(pos)"] =
            "{"
            "real scale3 = floor(pos.z*invPeriodicBoxSize.z); \\\n"
            "pos.xyz -= scale3*periodicBoxVecZ.xyz; \\\n"
            "real scale2 = floor(pos.y*invPeriodicBoxSize.y); \\\n"
            "pos.xy -= scale2*periodicBoxVecY.xy; \\\n"
            "real scale1 = floor(pos.x*invPeriodicBoxSize.x); \\\n"
            "pos.x -= scale1*periodicBoxVecX.x;}";
        compilationDefines["APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center)"] =
            "{"
            "real scale3 = floor((pos.z-center.z)*invPeriodicBoxSize.z+0.5f); \\\n"
            "pos.x -= scale3*periodicBoxVecZ.x; \\\n"
            "pos.y -= scale3*periodicBoxVecZ.y; \\\n"
            "pos.z -= scale3*periodicBoxVecZ.z; \\\n"
            "real scale2 = floor((pos.y-center.y)*invPeriodicBoxSize.y+0.5f); \\\n"
            "pos.x -= scale2*periodicBoxVecY.x; \\\n"
            "pos.y -= scale2*periodicBoxVecY.y; \\\n"
            "real scale1 = floor((pos.x-center.x)*invPeriodicBoxSize.x+0.5f); \\\n"
            "pos.x -= scale1*periodicBoxVecX.x;}";
    }
    else {
        compilationDefines["APPLY_PERIODIC_TO_DELTA(delta)"] =
            "delta.xyz -= floor(delta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;";
        compilationDefines["APPLY_PERIODIC_TO_POS(pos)"] =
            "pos.xyz -= floor(pos.xyz*invPeriodicBoxSize.xyz)*periodicBoxSize.xyz;";
        compilationDefines["APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center)"] =
            "{"
            "pos.x -= floor((pos.x-center.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x; \\\n"
            "pos.y -= floor((pos.y-center.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y; \\\n"
            "pos.z -= floor((pos.z-center.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;}";
    }

    // Create utilities objects.

    bonded = new OpenCLBondedUtilities(*this);
    nonbonded = new OpenCLNonbondedUtilities(*this);
    integration = new OpenCLIntegrationUtilities(*this, system);
    expression = new OpenCLExpressionUtilities(*this);
}

/**
 * Destroy the context, deleting the registered force infos, reorder listeners, pre- and
 * post-computations, the pinned host buffer, and the utility objects created by the constructor.
 */
OpenCLContext::~OpenCLContext() {
    for (auto info : forces)
        delete info;
    for (auto reorderListener : reorderListeners)
        delete reorderListener;
    for (auto pre : preComputations)
        delete pre;
    for (auto post : postComputations)
        delete post;
    // Deleting a NULL pointer is a no-op, so the members that may never have been allocated
    // need no explicit checks.
    delete pinnedBuffer;
    delete integration;
    delete expression;
    delete bonded;
    delete nonbonded;
}

494
void OpenCLContext::initialize() {
Peter Eastman's avatar
Peter Eastman committed
495
    bonded->initialize(system);
496
    numForceBuffers = std::max(numForceBuffers, (int) platformData.contexts.size());
Peter Eastman's avatar
Peter Eastman committed
497
    numForceBuffers = std::max(numForceBuffers, bonded->getNumForceBuffers());
498
    numForceBuffers = std::max(numForceBuffers, nonbonded->getNumForceBuffers());
499
    int energyBufferSize = max(numThreadBlocks*ThreadBlockSize, nonbonded->getNumEnergyBuffers());
500
    if (useDoublePrecision) {
peastman's avatar
peastman committed
501
502
503
504
        forceBuffers.initialize<mm_double4>(*this, paddedNumAtoms*numForceBuffers, "forceBuffers");
        force.initialize<mm_double4>(*this, &forceBuffers.getDeviceBuffer(), paddedNumAtoms, "force");
        energyBuffer.initialize<cl_double>(*this, energyBufferSize, "energyBuffer");
        energySum.initialize<cl_double>(*this, 1, "energySum");
505
    }
Peter Eastman's avatar
Peter Eastman committed
506
    else if (useMixedPrecision) {
peastman's avatar
peastman committed
507
508
509
510
        forceBuffers.initialize<mm_float4>(*this, paddedNumAtoms*numForceBuffers, "forceBuffers");
        force.initialize<mm_float4>(*this, &forceBuffers.getDeviceBuffer(), paddedNumAtoms, "force");
        energyBuffer.initialize<cl_double>(*this, energyBufferSize, "energyBuffer");
        energySum.initialize<cl_double>(*this, 1, "energySum");
Peter Eastman's avatar
Peter Eastman committed
511
512
    }
    else {
peastman's avatar
peastman committed
513
514
515
516
        forceBuffers.initialize<mm_float4>(*this, paddedNumAtoms*numForceBuffers, "forceBuffers");
        force.initialize<mm_float4>(*this, &forceBuffers.getDeviceBuffer(), paddedNumAtoms, "force");
        energyBuffer.initialize<cl_float>(*this, energyBufferSize, "energyBuffer");
        energySum.initialize<cl_float>(*this, 1, "energySum");
517
    }
518
519
520
521
    reduceForcesKernel.setArg<cl::Buffer>(0, longForceBuffer.getDeviceBuffer());
    reduceForcesKernel.setArg<cl::Buffer>(1, forceBuffers.getDeviceBuffer());
    reduceForcesKernel.setArg<cl_int>(2, paddedNumAtoms);
    reduceForcesKernel.setArg<cl_int>(3, numForceBuffers);
522
    addAutoclearBuffer(longForceBuffer);
peastman's avatar
peastman committed
523
524
    addAutoclearBuffer(forceBuffers);
    addAutoclearBuffer(energyBuffer);
525
526
527
    int numEnergyParamDerivs = energyParamDerivNames.size();
    if (numEnergyParamDerivs > 0) {
        if (useDoublePrecision || useMixedPrecision)
peastman's avatar
peastman committed
528
            energyParamDerivBuffer.initialize<cl_double>(*this, numEnergyParamDerivs*energyBufferSize, "energyParamDerivBuffer");
529
        else
peastman's avatar
peastman committed
530
531
            energyParamDerivBuffer.initialize<cl_float>(*this, numEnergyParamDerivs*energyBufferSize, "energyParamDerivBuffer");
        addAutoclearBuffer(energyParamDerivBuffer);
532
    }
533
534
535
    int bufferBytes = max(max(velm.getSize()*velm.getElementSize(),
            energyBufferSize*energyBuffer.getElementSize()),
            longForceBuffer.getSize()*longForceBuffer.getElementSize());
536
    pinnedBuffer = new cl::Buffer(context, CL_MEM_ALLOC_HOST_PTR, bufferBytes);
537
    pinnedMemory = currentQueue.enqueueMapBuffer(*pinnedBuffer, CL_TRUE, CL_MAP_READ | CL_MAP_WRITE, 0, bufferBytes);
538
539
540
541
542
543
544
    for (int i = 0; i < numAtoms; i++) {
        double mass = system.getParticleMass(i);
        if (useDoublePrecision || useMixedPrecision)
            ((mm_double4*) pinnedMemory)[i] = mm_double4(0.0, 0.0, 0.0, mass == 0.0 ? 0.0 : 1.0/mass);
        else
            ((mm_float4*) pinnedMemory)[i] = mm_float4(0.0f, 0.0f, 0.0f, mass == 0.0 ? 0.0f : (cl_float) (1.0/mass));
    }
peastman's avatar
peastman committed
545
    velm.upload(pinnedMemory);
546
    findMoleculeGroups();
547
    nonbonded->initialize(system);
548
549
}

550
551
void OpenCLContext::initializeContexts() {
    getPlatformData().initializeContexts(system);
552
553
}

554
555
556
557
558
/**
 * Register a force with this context.
 *
 * The generic registration is done by ComputeContext.  If the force info is an
 * OpenCLForceInfo, its required number of force buffers is also requested, so
 * the context allocates at least that many.
 *
 * @param force  information about the force being added
 */
void OpenCLContext::addForce(ComputeForceInfo* force) {
    ComputeContext::addForce(force);
    // Only OpenCL-specific force info carries a force buffer requirement.
    if (OpenCLForceInfo* openclInfo = dynamic_cast<OpenCLForceInfo*>(force))
        requestForceBuffers(openclInfo->getRequiredForceBuffers());
}

561
562
/**
 * Ensure at least minBuffers force buffers will be used.  The count only ever
 * grows; a request smaller than the current value has no effect.
 *
 * @param minBuffers  the minimum number of force buffers required
 */
void OpenCLContext::requestForceBuffers(int minBuffers) {
    if (minBuffers > numForceBuffers)
        numForceBuffers = minBuffers;
}

565
566
/**
 * Compile an OpenCL program with no additional preprocessor definitions.
 *
 * @param source             the OpenCL C source code of the program
 * @param optimizationFlags  compiler options, or NULL to use the default options
 * @return the built program
 */
cl::Program OpenCLContext::createProgram(const string source, const char* optimizationFlags) {
    const map<string, string> noDefines;
    return createProgram(source, noDefines, optimizationFlags);
}

569
/**
 * Compile an OpenCL program from source for this context's device.
 *
 * The text actually handed to the compiler is assembled in this order:
 *   1. a comment recording the compilation options (if any),
 *   2. the context-wide compilationDefines, except those whose name also appears
 *      in defines (so the per-program value wins instead of being redefined),
 *   3. the cl_khr_fp64 pragma when the device supports double precision,
 *   4. typedefs for real/real2/real3/real4 (double in double precision mode,
 *      otherwise float) and mixed/mixed2/mixed3/mixed4 (double in double or
 *      mixed precision mode, otherwise float),
 *   5. the common kernel source,
 *   6. the per-program defines,
 *   7. the caller's source.
 *
 * @param source             the OpenCL C source code of the program
 * @param defines            preprocessor macros to define (name -> value); an empty
 *                           value produces a bare "#define name"
 * @param optimizationFlags  compiler options, or NULL to use defaultOptimizationOptions
 * @return the built program
 * @throws OpenMMException if the program fails to build; the message contains the build log
 */
cl::Program OpenCLContext::createProgram(const string source, const map<string, string>& defines, const char* optimizationFlags) {
    string options = (optimizationFlags == NULL ? defaultOptimizationOptions : string(optimizationFlags));
    stringstream src;

    // Emit "#define name" or "#define name value" on its own line.
    auto writeDefine = [&src](const string& name, const string& value) {
        src << "#define " << name;
        if (!value.empty())
            src << " " << value;
        src << '\n';
    };

    if (!options.empty())
        src << "// Compilation Options: " << options << "\n\n";
    for (auto& pair : compilationDefines) {
        // Skip context-wide defines the caller redefines, to avoid duplicate definitions.
        if (defines.find(pair.first) == defines.end())
            writeDefine(pair.first, pair.second);
    }
    if (!compilationDefines.empty())
        src << '\n';

    // Precision-dependent scalar and vector types used throughout the kernels.
    if (supportsDoublePrecision)
        src << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n";
    if (useDoublePrecision) {
        src << "typedef double real;\n";
        src << "typedef double2 real2;\n";
        src << "typedef double3 real3;\n";
        src << "typedef double4 real4;\n";
    }
    else {
        src << "typedef float real;\n";
        src << "typedef float2 real2;\n";
        src << "typedef float3 real3;\n";
        src << "typedef float4 real4;\n";
    }
    if (useDoublePrecision || useMixedPrecision) {
        src << "typedef double mixed;\n";
        src << "typedef double2 mixed2;\n";
        src << "typedef double3 mixed3;\n";
        src << "typedef double4 mixed4;\n";
    }
    else {
        src << "typedef float mixed;\n";
        src << "typedef float2 mixed2;\n";
        src << "typedef float3 mixed3;\n";
        src << "typedef float4 mixed4;\n";
    }
    src << OpenCLKernelSources::common << '\n';

    for (auto& pair : defines)
        writeDefine(pair.first, pair.second);
    if (!defines.empty())
        src << '\n';
    src << source << '\n';

    cl::Program::Sources sources({src.str()});
    cl::Program program(context, sources);
    try {
        program.build(vector<cl::Device>(1, device), options.c_str());
    }
    catch (const cl::Error&) {
        throw OpenMMException("Error compiling kernel: "+program.getBuildInfo<CL_PROGRAM_BUILD_LOG>(device));
    }
    return program;
}

631
632
633
634
635
636
637
638
639
640
641
642
/**
 * Get the command queue that work is currently enqueued on (the default queue
 * unless setQueue() has selected another one).
 */
cl::CommandQueue& OpenCLContext::getQueue() {
    return this->currentQueue;
}

/**
 * Select the command queue subsequent work is enqueued on.  A copy of the queue
 * handle is stored; call restoreDefaultQueue() to switch back.
 *
 * @param queue  the queue to use from now on
 */
void OpenCLContext::setQueue(cl::CommandQueue& queue) {
    this->currentQueue = queue;
}

void OpenCLContext::restoreDefaultQueue() {
    currentQueue = defaultQueue;
}

643
644
645
646
647
648
/**
 * Allocate a new default-constructed OpenCLArray (presumably not yet initialized;
 * callers are expected to initialize it).  The context does not keep the pointer,
 * so the caller owns the returned object.
 */
OpenCLArray* OpenCLContext::createArray() {
    OpenCLArray* created = new OpenCLArray();
    return created;
}

/**
 * Create a new event object bound to this context, wrapped in a shared_ptr to
 * the platform-independent ComputeEventImpl interface.
 */
ComputeEvent OpenCLContext::createEvent() {
    shared_ptr<ComputeEventImpl> event(new OpenCLEvent(*this));
    return event;
}

651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
/**
 * Compile a program with createProgram() (using the default optimization options)
 * and wrap it in the platform-independent ComputeProgramImpl interface.
 *
 * @param source   the OpenCL C source code of the program
 * @param defines  preprocessor macros to define (name -> value)
 * @return a shared handle to the compiled program
 */
ComputeProgram OpenCLContext::compileProgram(const std::string source, const std::map<std::string, std::string>& defines) {
    shared_ptr<ComputeProgramImpl> compiled(new OpenCLProgram(*this, createProgram(source, defines)));
    return compiled;
}

/**
 * Get the OpenCLArray behind a generic array.
 *
 * If the argument is a ComputeArray, the array it wraps is examined; otherwise
 * the argument itself is.
 *
 * @param array  the array to unwrap
 * @return the underlying OpenCLArray
 * @throws OpenMMException if the (unwrapped) array is not an OpenCLArray
 */
OpenCLArray& OpenCLContext::unwrap(ArrayInterface& array) const {
    ArrayInterface* target = &array;
    if (ComputeArray* wrapper = dynamic_cast<ComputeArray*>(&array))
        target = &wrapper->getArray();
    OpenCLArray* clarray = dynamic_cast<OpenCLArray*>(target);
    if (clarray == NULL)
        throw OpenMMException("Array argument is not an OpenCLArray");
    return *clarray;
}

668
669
670
671
/**
 * Enqueue a one-dimensional kernel launch on the current queue.
 *
 * The global size is workUnits rounded up to a multiple of blockSize, but never
 * more than numThreadBlocks blocks.  When workUnits exceeds that cap, kernels are
 * presumably written to loop over the remaining units themselves.
 *
 * @param kernel     the kernel to execute; its arguments must already be set
 * @param workUnits  the number of work units to process
 * @param blockSize  the work group size, or -1 to use ThreadBlockSize
 * @throws OpenMMException if the launch fails; the message includes the kernel
 *                         name and the OpenCL error code
 */
void OpenCLContext::executeKernel(cl::Kernel& kernel, int workUnits, int blockSize) {
    if (blockSize == -1)
        blockSize = ThreadBlockSize;
    int size = std::min((workUnits+blockSize-1)/blockSize, numThreadBlocks)*blockSize;
    // There is nothing to do, and OpenCL versions before 2.1 reject a zero global
    // size with CL_INVALID_GLOBAL_WORK_SIZE, so skip empty launches.
    if (size <= 0)
        return;
    try {
        currentQueue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange(size), cl::NDRange(blockSize));
    }
    catch (const cl::Error& err) {
        stringstream str;
        str<<"Error invoking kernel "<<kernel.getInfo<CL_KERNEL_FUNCTION_NAME>()<<": "<<err.what()<<" ("<<err.err()<<")";
        throw OpenMMException(str.str());
    }
}

682
683
/**
 * Zero the contents of an array (which must be, or wrap, an OpenCLArray).
 *
 * @param array  the array to clear
 */
void OpenCLContext::clearBuffer(ArrayInterface& array) {
    OpenCLArray& clarray = unwrap(array);
    int bytes = array.getSize()*array.getElementSize();
    clearBuffer(clarray.getDeviceBuffer(), bytes);
}

686
/**
 * Zero a device buffer by launching clearBufferKernel.
 *
 * @param memory  the buffer to clear
 * @param size    the number of bytes to clear.  It is converted to a count of
 *                4-byte words (size/4), so any trailing bytes beyond a multiple
 *                of 4 are not cleared.
 */
void OpenCLContext::clearBuffer(cl::Memory& memory, int size) {
    const int numWords = size/4;
    clearBufferKernel.setArg<cl::Memory>(0, memory);
    clearBufferKernel.setArg<cl_int>(1, numWords);
    executeKernel(clearBufferKernel, numWords, 128);
}

693
694
/**
 * Register an array (which must be, or wrap, an OpenCLArray) to be zeroed by
 * clearAutoclearBuffers().
 *
 * @param array  the array to clear automatically
 */
void OpenCLContext::addAutoclearBuffer(ArrayInterface& array) {
    OpenCLArray& clarray = unwrap(array);
    int bytes = array.getSize()*array.getElementSize();
    addAutoclearBuffer(clarray.getDeviceBuffer(), bytes);
}

697
698
/**
 * Register a device buffer to be zeroed by clearAutoclearBuffers().
 *
 * Only a pointer to the buffer is stored, so it must outlive its registration.
 *
 * @param memory  the buffer to clear automatically
 * @param size    its size in bytes; it is recorded as a count of 4-byte words (size/4)
 */
void OpenCLContext::addAutoclearBuffer(cl::Memory& memory, int size) {
    const int numWords = size/4;
    autoclearBuffers.push_back(&memory);
    autoclearBufferSizes.push_back(numWords);
}

void OpenCLContext::clearAutoclearBuffers() {
    int base = 0;
    int total = autoclearBufferSizes.size();
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
    while (total-base >= 6) {
        clearSixBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearSixBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearSixBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearSixBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearSixBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearSixBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearSixBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearSixBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
        clearSixBuffersKernel.setArg<cl::Memory>(8, *autoclearBuffers[base+4]);
        clearSixBuffersKernel.setArg<cl_int>(9, autoclearBufferSizes[base+4]);
        clearSixBuffersKernel.setArg<cl::Memory>(10, *autoclearBuffers[base+5]);
        clearSixBuffersKernel.setArg<cl_int>(11, autoclearBufferSizes[base+5]);
        executeKernel(clearSixBuffersKernel, max(max(max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), autoclearBufferSizes[base+4]), autoclearBufferSizes[base+5]), 128);
        base += 6;
    }
    if (total-base == 5) {
        clearFiveBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearFiveBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearFiveBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearFiveBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearFiveBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearFiveBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearFiveBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearFiveBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
        clearFiveBuffersKernel.setArg<cl::Memory>(8, *autoclearBuffers[base+4]);
        clearFiveBuffersKernel.setArg<cl_int>(9, autoclearBufferSizes[base+4]);
        executeKernel(clearFiveBuffersKernel, max(max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), autoclearBufferSizes[base+4]), 128);
    }
    else if (total-base == 4) {
735
736
737
738
739
740
741
742
        clearFourBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearFourBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearFourBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearFourBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearFourBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearFourBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearFourBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearFourBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
743
        executeKernel(clearFourBuffersKernel, max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), 128);
744
    }
745
    else if (total-base == 3) {
746
747
748
749
750
751
        clearThreeBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearThreeBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearThreeBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearThreeBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearThreeBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearThreeBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
752
        executeKernel(clearThreeBuffersKernel, max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), 128);
753
754
755
756
757
758
    }
    else if (total-base == 2) {
        clearTwoBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearTwoBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearTwoBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearTwoBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
759
        executeKernel(clearTwoBuffersKernel, max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), 128);
760
761
    }
    else if (total-base == 1) {
762
        clearBuffer(*autoclearBuffers[base], autoclearBufferSizes[base]*4);
763
764
765
    }
}

766
void OpenCLContext::reduceForces() {
767
    executeKernel(reduceForcesKernel, paddedNumAtoms, 128);
768
769
}

770
/**
 * Reduce an array treated as numBuffers equal-length sub-buffers, using
 * reduceReal4Kernel.  Presumably the sub-buffers are summed into the first one;
 * TODO confirm against the kernel source.
 *
 * @param array       the array holding all sub-buffers back to back
 * @param numBuffers  the number of sub-buffers; must divide into array.getSize()
 */
void OpenCLContext::reduceBuffer(OpenCLArray& array, int numBuffers) {
    const int elementsPerBuffer = array.getSize()/numBuffers;
    reduceReal4Kernel.setArg<cl::Buffer>(0, array.getDeviceBuffer());
    reduceReal4Kernel.setArg<cl_int>(1, elementsPerBuffer);
    reduceReal4Kernel.setArg<cl_int>(2, numBuffers);
    executeKernel(reduceReal4Kernel, elementsPerBuffer, 128);
}
777

Peter Eastman's avatar
Peter Eastman committed
778
779
780
781
double OpenCLContext::reduceEnergy() {
    int workGroupSize  = device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>();
    if (workGroupSize > 512)
        workGroupSize = 512;
peastman's avatar
peastman committed
782
783
784
    reduceEnergyKernel.setArg<cl::Buffer>(0, energyBuffer.getDeviceBuffer());
    reduceEnergyKernel.setArg<cl::Buffer>(1, energySum.getDeviceBuffer());
    reduceEnergyKernel.setArg<cl_int>(2, energyBuffer.getSize());
Peter Eastman's avatar
Peter Eastman committed
785
    reduceEnergyKernel.setArg<cl_int>(3, workGroupSize);
peastman's avatar
peastman committed
786
    reduceEnergyKernel.setArg(4, workGroupSize*energyBuffer.getElementSize(), NULL);
Peter Eastman's avatar
Peter Eastman committed
787
788
789
    executeKernel(reduceEnergyKernel, workGroupSize, workGroupSize);
    if (getUseDoublePrecision() || getUseMixedPrecision()) {
        double energy;
peastman's avatar
peastman committed
790
        energySum.download(&energy);
Peter Eastman's avatar
Peter Eastman committed
791
792
793
794
        return energy;
    }
    else {
        float energy;
peastman's avatar
peastman committed
795
        energySum.download(&energy);
Peter Eastman's avatar
Peter Eastman committed
796
797
798
799
        return energy;
    }
}

800
void OpenCLContext::setCharges(const vector<double>& charges) {
peastman's avatar
peastman committed
801
802
    if (!chargeBuffer.isInitialized())
        chargeBuffer.initialize(*this, numAtoms, useDoublePrecision ? sizeof(double) : sizeof(float), "chargeBuffer");
peastman's avatar
peastman committed
803
804
805
    vector<double> c(numAtoms);
    for (int i = 0; i < numAtoms; i++)
        c[i] = charges[i];
806
    chargeBuffer.upload(c, true);
peastman's avatar
peastman committed
807
808
809
    setChargesKernel.setArg<cl::Buffer>(0, chargeBuffer.getDeviceBuffer());
    setChargesKernel.setArg<cl::Buffer>(1, posq.getDeviceBuffer());
    setChargesKernel.setArg<cl::Buffer>(2, atomIndexDevice.getDeviceBuffer());
810
811
812
813
    setChargesKernel.setArg<cl_int>(3, numAtoms);
    executeKernel(setChargesKernel, numAtoms);
}

814
815
816
817
818
819
/**
 * Claim the right to store charges in the fourth component of posq.
 *
 * @return true for the first caller only; every later call returns false
 */
bool OpenCLContext::requestPosqCharges() {
    if (hasAssignedPosqCharges)
        return false;
    hasAssignedPosqCharges = true;
    return true;
}

820
821
822
823
824
825
826
827
828
void OpenCLContext::addEnergyParameterDerivative(const string& param) {
    // See if this parameter has already been registered.
    
    for (int i = 0; i < energyParamDerivNames.size(); i++)
        if (param == energyParamDerivNames[i])
            return;
    energyParamDerivNames.push_back(param);
}

829
830
void OpenCLContext::flushQueue() {
    getQueue().flush();
831
}