OpenCLContext.cpp 59.4 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2016 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

27
28
29
30
#ifdef WIN32
  #define _USE_MATH_DEFINES // Needed to get M_PI
#endif
#include <cmath>
#include "OpenCLContext.h"
#include "OpenCLArray.h"
#include "OpenCLBondedUtilities.h"
#include "OpenCLForceInfo.h"
#include "OpenCLIntegrationUtilities.h"
#include "OpenCLKernelSources.h"
#include "OpenCLNonbondedUtilities.h"
#include "hilbert.h"
#include "openmm/Platform.h"
#include "openmm/System.h"
#include "openmm/VirtualSite.h"
#include "openmm/internal/ContextImpl.h"
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iostream>
#include <set>
#include <sstream>
#include <typeinfo>
49
50

using namespace OpenMM;
51
using namespace std;
52

53
54
55
// Fallback values for the device-info tokens used to query an NVIDIA GPU's
// compute capability, for OpenCL headers that do not define them.  They are only
// passed to clGetDeviceInfo() after the device has reported the
// cl_nv_device_attribute_query extension.
#ifndef CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV
  #define CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV 0x4000
#endif
#ifndef CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV
  #define CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV 0x4001
#endif

// Work group size advertised to kernels through the WORK_GROUP_SIZE define.
const int OpenCLContext::ThreadBlockSize = 64;
// Atom counts are padded up to a multiple of this value, and atom blocks contain this many atoms.
const int OpenCLContext::TileSize = 32;

63
/**
 * Notification callback registered with the OpenCL context.  It reports errors
 * raised by the OpenCL implementation on stderr.
 *
 * @param errinfo       human readable description of the error
 * @param private_info  implementation specific binary data (unused)
 * @param cb            size of private_info in bytes (unused)
 * @param user_data     user pointer supplied at context creation (unused)
 */
static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_info, size_t cb, void* user_data) {
    if (errinfo == NULL)
        errinfo = "(no error information provided)"; // Never hand a null pointer to strncmp() or operator<<.
    static const char skip[] = "OpenCL Build Warning : Compiler build log:";
    if (strncmp(errinfo, skip, sizeof(skip)-1) == 0)
        return; // OS X Lion insists on calling this for every build warning, even though they aren't errors.
    std::cerr << "OpenCL internal error: " << errinfo << std::endl;
}

70
/**
 * Create a context for running a System on a single OpenCL device.
 *
 * The constructor selects the platform and device, creates the cl::Context and
 * command queue, allocates the position and velocity arrays, compiles the shared
 * utility kernels, fills in the compilation defines that depend on the device and
 * precision mode, and finally creates the helper objects (work thread, bonded,
 * nonbonded, integration and expression utilities).
 *
 * @param system         the System to simulate
 * @param platformIndex  index of the OpenCL platform to use, or -1 to search all platforms
 * @param deviceIndex    index of the device to use within a platform, or -1 to pick the
 *                       device with the highest estimated speed
 * @param precision      "single", "mixed" or "double"
 * @param platformData   the platform data this context belongs to; its current number of
 *                       contexts becomes this context's index
 * @throws OpenMMException if an argument is illegal, no compatible device is found, or
 *                         the OpenCL runtime reports an error during setup
 */
OpenCLContext::OpenCLContext(const System& system, int platformIndex, int deviceIndex, const string& precision, OpenCLPlatform::PlatformData& platformData) :
        system(system), time(0.0), platformData(platformData), stepCount(0), computeForceCount(0), stepsSinceReorder(99999), atomsWereReordered(false), posq(NULL),
        posqCorrection(NULL), velm(NULL), forceBuffers(NULL), longForceBuffer(NULL), energyBuffer(NULL), atomIndexDevice(NULL), integration(NULL),
        expression(NULL), bonded(NULL), nonbonded(NULL), thread(NULL) {
    // These are only allocated by initialize(), but the destructor inspects them
    // unconditionally.  Make sure they hold a well defined value even if
    // initialize() is never called.
    force = NULL;
    pinnedBuffer = NULL;

    if (precision == "single") {
        useDoublePrecision = false;
        useMixedPrecision = false;
    }
    else if (precision == "mixed") {
        useDoublePrecision = false;
        useMixedPrecision = true;
    }
    else if (precision == "double") {
        useDoublePrecision = true;
        useMixedPrecision = false;
    }
    else
        throw OpenMMException("Illegal value for Precision: "+precision);
    try {
        contextIndex = platformData.contexts.size();
        std::vector<cl::Platform> platforms;
        cl::Platform::get(&platforms);
        if (platformIndex < -1 || platformIndex >= (int) platforms.size())
            throw OpenMMException("Illegal value for OpenCLPlatformIndex: "+intToString(platformIndex));
        const int minThreadBlockSize = 32;

        // Scan the candidate devices and remember the one with the highest estimated speed.

        int bestSpeed = -1;
        int bestDevice = -1;
        int bestPlatform = -1;
        for (int j = 0; j < (int) platforms.size(); j++) {
            // If they supplied a valid platformIndex, we only look through that platform
            if (j != platformIndex && platformIndex != -1)
                continue;

            string platformVendor = platforms[j].getInfo<CL_PLATFORM_VENDOR>();
            vector<cl::Device> devices;
            platforms[j].getDevices(CL_DEVICE_TYPE_ALL, &devices);
            if (deviceIndex < -1 || deviceIndex >= (int) devices.size())
                throw OpenMMException("Illegal value for DeviceIndex: "+intToString(deviceIndex));

            for (int i = 0; i < (int) devices.size(); i++) {
                // If they supplied a valid deviceIndex, we only look through that one
                if (i != deviceIndex && deviceIndex != -1)
                    continue;
                if (platformVendor == "Apple" && (devices[i].getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU))
                    continue; // The CPU device on OS X won't work correctly.
                if (useMixedPrecision || useDoublePrecision) {
                    bool supportsDouble = (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_fp64") != string::npos);
                    if (!supportsDouble)
                        continue; // This device does not support double precision.
                }
                int maxSize = devices[i].getInfo<CL_DEVICE_MAX_WORK_ITEM_SIZES>()[0];
                int processingElementsPerComputeUnit = 8;
                if (devices[i].getInfo<CL_DEVICE_TYPE>() != CL_DEVICE_TYPE_GPU) {
                    processingElementsPerComputeUnit = 1;
                }
                else if (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_nv_device_attribute_query") != string::npos) {
                    // Start from 0 so a failed query falls back to the pre-2.0 estimate instead of
                    // reading an indeterminate value.
                    cl_uint computeCapabilityMajor = 0;
                    clGetDeviceInfo(devices[i](), CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV, sizeof(cl_uint), &computeCapabilityMajor, NULL);
                    processingElementsPerComputeUnit = (computeCapabilityMajor < 2 ? 8 : 32);
                }
                else if (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_amd_device_attribute_query") != string::npos) {
                    // This attribute does not ensure that all queries are supported by the runtime (it may be an older runtime,
                    // or the CPU device) so still have to check for errors.
                    try {
#ifdef CL_DEVICE_SIMD_WIDTH_AMD
                        processingElementsPerComputeUnit =
                            // AMD GPUs either have a single VLIW SIMD or multiple scalar SIMDs.
                            // The SIMD width is the number of threads the SIMD executes per cycle.
                            // This will be less than the wavefront width since it takes several
                            // cycles to execute the full wavefront.
                            // The SIMD instruction width is the VLIW instruction width (or 1 for scalar),
                            // this is the number of ALUs that can be executing per instruction per thread.
                            devices[i].getInfo<CL_DEVICE_SIMD_PER_COMPUTE_UNIT_AMD>() *
                            devices[i].getInfo<CL_DEVICE_SIMD_WIDTH_AMD>() *
                            devices[i].getInfo<CL_DEVICE_SIMD_INSTRUCTION_WIDTH_AMD>();
                        // Just in case any of the queries return 0.
                        if (processingElementsPerComputeUnit <= 0)
                            processingElementsPerComputeUnit = 1;
#endif
                    }
                    catch (const cl::Error&) {
                        // Runtime does not support the queries so use default.
                    }
                }
                int speed = devices[i].getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>()*processingElementsPerComputeUnit*devices[i].getInfo<CL_DEVICE_MAX_CLOCK_FREQUENCY>();
                if (maxSize >= minThreadBlockSize && speed > bestSpeed) {
                    bestDevice = i;
                    bestSpeed = speed;
                    bestPlatform = j;
                }
            }
        }

        if (bestPlatform == -1)
            throw OpenMMException("No compatible OpenCL platform is available");

        if (bestDevice == -1)
            throw OpenMMException("No compatible OpenCL device is available");

        vector<cl::Device> devices;
        platforms[bestPlatform].getDevices(CL_DEVICE_TYPE_ALL, &devices);
        string platformVendor = platforms[bestPlatform].getInfo<CL_PLATFORM_VENDOR>();
        device = devices[bestDevice];

        this->deviceIndex = bestDevice;
        this->platformIndex = bestPlatform;
        if (device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>() < minThreadBlockSize)
            throw OpenMMException("The specified OpenCL device is not compatible with OpenMM");
        compilationDefines["WORK_GROUP_SIZE"] = intToString(ThreadBlockSize);
        if (platformVendor.size() >= 5 && platformVendor.substr(0, 5) == "Intel")
            defaultOptimizationOptions = "";
        else
            defaultOptimizationOptions = "-cl-mad-enable -cl-no-signed-zeros";
        supports64BitGlobalAtomics = (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_int64_base_atomics") != string::npos);
        supportsDoublePrecision = (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_fp64") != string::npos);
        if ((useDoublePrecision || useMixedPrecision) && !supportsDoublePrecision)
            throw OpenMMException("This device does not support double precision");

        // Apply vendor specific settings.

        string vendor = device.getInfo<CL_DEVICE_VENDOR>();
        int numThreadBlocksPerComputeUnit = 6;
        if (vendor.size() >= 6 && vendor.substr(0, 6) == "NVIDIA") {
            compilationDefines["WARPS_ARE_ATOMIC"] = "";
            simdWidth = 32;
            if (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_nv_device_attribute_query") != string::npos) {
                // Compute level 1.2 and later Nvidia GPUs support 64 bit atomics, even though they don't list the
                // proper extension as supported.  We only use them on compute level 2.0 or later, since they're very
                // slow on earlier GPUs.

                cl_uint computeCapabilityMajor = 0; // Stays 0 (no override) if the query fails.
                clGetDeviceInfo(device(), CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV, sizeof(cl_uint), &computeCapabilityMajor, NULL);
                if (computeCapabilityMajor > 1)
                    supports64BitGlobalAtomics = true;
                if (computeCapabilityMajor == 5) {
                    // Workaround for a bug in Maxwell on CUDA 6.x.

                    string platformVersion = platforms[bestPlatform].getInfo<CL_PLATFORM_VERSION>();
                    if (platformVersion.find("CUDA 6") != string::npos)
                        supports64BitGlobalAtomics = false;
                }
            }
        }
        else if (vendor.size() >= 28 && vendor.substr(0, 28) == "Advanced Micro Devices, Inc.") {
            if (device.getInfo<CL_DEVICE_TYPE>() != CL_DEVICE_TYPE_GPU) {
                /// \todo Is 6 a good value for the OpenCL CPU device?
                // numThreadBlocksPerComputeUnit = ?;
                simdWidth = 1;
            }
            else {
                bool amdPostSdk2_4 = false;
                // Default to 1 which will use the default kernels.
                simdWidth = 1;
                if (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_amd_device_attribute_query") != string::npos) {
                    // This attribute does not ensure that all queries are supported by the runtime so still have to
                    // check for errors.
                    try {
#ifdef CL_DEVICE_SIMD_PER_COMPUTE_UNIT_AMD
                        // Must catch cl::Error as will fail if runtime does not support queries.

                        cl_uint simdPerComputeUnit = device.getInfo<CL_DEVICE_SIMD_PER_COMPUTE_UNIT_AMD>();
                        simdWidth = device.getInfo<CL_DEVICE_WAVEFRONT_WIDTH_AMD>();

                        // If the GPU has multiple SIMDs per compute unit then it uses the scalar instruction
                        // set instead of the VLIW instruction set. It therefore needs more thread blocks per
                        // compute unit to hide memory latency.
                        if (simdPerComputeUnit > 1)
                            numThreadBlocksPerComputeUnit = 4 * simdPerComputeUnit;

                        // If the queries are supported then must be newer than SDK 2.4.
                        amdPostSdk2_4 = true;
#endif
                    }
                    catch (const cl::Error&) {
                        // Runtime does not support the query so is unlikely to be the newer scalar GPU.
                        // Stay with the default simdWidth and numThreadBlocksPerComputeUnit.
                    }
                }
                // AMD APP SDK 2.4 has a performance problem with atomics. Enable the work around. This is fixed after SDK 2.4.
                if (!amdPostSdk2_4)
                    compilationDefines["AMD_ATOMIC_WORK_AROUND"] = "";
            }
        }
        else
            simdWidth = 1;
        if (supports64BitGlobalAtomics)
            compilationDefines["SUPPORTS_64_BIT_ATOMICS"] = "";
        if (supportsDoublePrecision)
            compilationDefines["SUPPORTS_DOUBLE_PRECISION"] = "";
        if (simdWidth >= 32)
            compilationDefines["SYNC_WARPS"] = "";
        else
            compilationDefines["SYNC_WARPS"] = "barrier(CLK_LOCAL_MEM_FENCE)";

        // Create the OpenCL context and command queue for the selected device.

        vector<cl::Device> contextDevices;
        contextDevices.push_back(device);
        cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[bestPlatform](), 0};
        context = cl::Context(contextDevices, cprops, errorCallback);
        defaultQueue = cl::CommandQueue(context, device);
        currentQueue = defaultQueue;

        // Work out the problem size and allocate the position and velocity arrays.

        numAtoms = system.getNumParticles();
        paddedNumAtoms = TileSize*((numAtoms+TileSize-1)/TileSize);
        numAtomBlocks = (paddedNumAtoms+(TileSize-1))/TileSize;
        numThreadBlocks = numThreadBlocksPerComputeUnit*device.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
        if (useDoublePrecision) {
            posq = OpenCLArray::create<mm_double4>(*this, paddedNumAtoms, "posq");
            velm = OpenCLArray::create<mm_double4>(*this, paddedNumAtoms, "velm");
            compilationDefines["USE_DOUBLE_PRECISION"] = "1";
            compilationDefines["convert_real4"] = "convert_double4";
            compilationDefines["convert_mixed4"] = "convert_double4";
        }
        else if (useMixedPrecision) {
            posq = OpenCLArray::create<mm_float4>(*this, paddedNumAtoms, "posq");
            // Give the correction array its own name so it can be told apart from posq.
            posqCorrection = OpenCLArray::create<mm_float4>(*this, paddedNumAtoms, "posqCorrection");
            velm = OpenCLArray::create<mm_double4>(*this, paddedNumAtoms, "velm");
            compilationDefines["USE_MIXED_PRECISION"] = "1";
            compilationDefines["convert_real4"] = "convert_float4";
            compilationDefines["convert_mixed4"] = "convert_double4";
        }
        else {
            posq = OpenCLArray::create<mm_float4>(*this, paddedNumAtoms, "posq");
            velm = OpenCLArray::create<mm_float4>(*this, paddedNumAtoms, "velm");
            compilationDefines["convert_real4"] = "convert_float4";
            compilationDefines["convert_mixed4"] = "convert_float4";
        }
        posCellOffsets.resize(paddedNumAtoms, mm_int4(0, 0, 0, 0));
    }
    catch (const cl::Error& err) {
        // Translate OpenCL failures into the exception type used by the rest of OpenMM.
        std::stringstream str;
        str<<"Error initializing context: "<<err.what()<<" ("<<err.err()<<")";
        throw OpenMMException(str.str());
    }

    // Create utility kernels that are used in multiple places.

    cl::Program utilities = createProgram(OpenCLKernelSources::utilities);
    clearBufferKernel = cl::Kernel(utilities, "clearBuffer");
    clearTwoBuffersKernel = cl::Kernel(utilities, "clearTwoBuffers");
    clearThreeBuffersKernel = cl::Kernel(utilities, "clearThreeBuffers");
    clearFourBuffersKernel = cl::Kernel(utilities, "clearFourBuffers");
    clearFiveBuffersKernel = cl::Kernel(utilities, "clearFiveBuffers");
    clearSixBuffersKernel = cl::Kernel(utilities, "clearSixBuffers");
    reduceReal4Kernel = cl::Kernel(utilities, "reduceReal4Buffer");
    if (supports64BitGlobalAtomics)
        reduceForcesKernel = cl::Kernel(utilities, "reduceForces");

    // Decide whether native_sqrt(), native_rsqrt(), and native_recip() are sufficiently accurate to use.

    if (!useDoublePrecision) {
        // Run a test kernel over a spread of input values and compare the device's
        // results against values computed on the host.
        cl::Kernel accuracyKernel(utilities, "determineNativeAccuracy");
        OpenCLArray valuesArray(*this, 20, sizeof(mm_float8), "values");
        vector<mm_float8> values(valuesArray.getSize());
        float nextValue = 1e-4f;
        for (int i = 0; i < (int) values.size(); ++i) {
            values[i].s0 = nextValue;
            nextValue *= (float) M_PI;
        }
        valuesArray.upload(values);
        accuracyKernel.setArg<cl::Buffer>(0, valuesArray.getDeviceBuffer());
        accuracyKernel.setArg<cl_int>(1, values.size());
        executeKernel(accuracyKernel, values.size());
        valuesArray.download(values);
        double maxSqrtError = 0.0, maxRsqrtError = 0.0, maxRecipError = 0.0, maxExpError = 0.0, maxLogError = 0.0;
        for (int i = 0; i < (int) values.size(); ++i) {
            double v = values[i].s0;
            double correctSqrt = sqrt(v);
            maxSqrtError = max(maxSqrtError, fabs(correctSqrt-values[i].s1)/correctSqrt);
            maxRsqrtError = max(maxRsqrtError, fabs(1.0/correctSqrt-values[i].s2)*correctSqrt);
            maxRecipError = max(maxRecipError, fabs(1.0/v-values[i].s3)/values[i].s3);
            maxExpError = max(maxExpError, fabs(exp(v)-values[i].s4)/values[i].s4);
            maxLogError = max(maxLogError, fabs(log(v)-values[i].s5)/values[i].s5);
        }
        compilationDefines["SQRT"] = (maxSqrtError < 1e-6) ? "native_sqrt" : "sqrt";
        compilationDefines["RSQRT"] = (maxRsqrtError < 1e-6) ? "native_rsqrt" : "rsqrt";
        compilationDefines["RECIP"] = (maxRecipError < 1e-6) ? "native_recip" : "1.0f/";
        compilationDefines["EXP"] = (maxExpError < 1e-6) ? "native_exp" : "exp";
        compilationDefines["LOG"] = (maxLogError < 1e-6) ? "native_log" : "log";
    }
    else {
        compilationDefines["SQRT"] = "sqrt";
        compilationDefines["RSQRT"] = "rsqrt";
        compilationDefines["RECIP"] = "1.0/";
        compilationDefines["EXP"] = "exp";
        compilationDefines["LOG"] = "log";
    }

    // Set defines for applying periodic boundary conditions.

    Vec3 boxVectors[3];
    system.getDefaultPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
    // The box is triclinic if any off-diagonal component of the box vectors is nonzero.
    boxIsTriclinic = (boxVectors[0][1] != 0.0 || boxVectors[0][2] != 0.0 ||
                      boxVectors[1][0] != 0.0 || boxVectors[1][2] != 0.0 ||
                      boxVectors[2][0] != 0.0 || boxVectors[2][1] != 0.0);
    if (boxIsTriclinic) {
        compilationDefines["APPLY_PERIODIC_TO_DELTA(delta)"] =
            "{"
            "real scale3 = floor(delta.z*invPeriodicBoxSize.z+0.5f); \\\n"
            "delta.xyz -= scale3*periodicBoxVecZ.xyz; \\\n"
            "real scale2 = floor(delta.y*invPeriodicBoxSize.y+0.5f); \\\n"
            "delta.xy -= scale2*periodicBoxVecY.xy; \\\n"
            "real scale1 = floor(delta.x*invPeriodicBoxSize.x+0.5f); \\\n"
            "delta.x -= scale1*periodicBoxVecX.x;}";
        compilationDefines["APPLY_PERIODIC_TO_POS(pos)"] =
            "{"
            "real scale3 = floor(pos.z*invPeriodicBoxSize.z); \\\n"
            "pos.xyz -= scale3*periodicBoxVecZ.xyz; \\\n"
            "real scale2 = floor(pos.y*invPeriodicBoxSize.y); \\\n"
            "pos.xy -= scale2*periodicBoxVecY.xy; \\\n"
            "real scale1 = floor(pos.x*invPeriodicBoxSize.x); \\\n"
            "pos.x -= scale1*periodicBoxVecX.x;}";
        compilationDefines["APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center)"] =
            "{"
            "real scale3 = floor((pos.z-center.z)*invPeriodicBoxSize.z+0.5f); \\\n"
            "pos.x -= scale3*periodicBoxVecZ.x; \\\n"
            "pos.y -= scale3*periodicBoxVecZ.y; \\\n"
            "pos.z -= scale3*periodicBoxVecZ.z; \\\n"
            "real scale2 = floor((pos.y-center.y)*invPeriodicBoxSize.y+0.5f); \\\n"
            "pos.x -= scale2*periodicBoxVecY.x; \\\n"
            "pos.y -= scale2*periodicBoxVecY.y; \\\n"
            "real scale1 = floor((pos.x-center.x)*invPeriodicBoxSize.x+0.5f); \\\n"
            "pos.x -= scale1*periodicBoxVecX.x;}";
    }
    else {
        compilationDefines["APPLY_PERIODIC_TO_DELTA(delta)"] =
            "delta.xyz -= floor(delta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;";
        compilationDefines["APPLY_PERIODIC_TO_POS(pos)"] =
            "pos.xyz -= floor(pos.xyz*invPeriodicBoxSize.xyz)*periodicBoxSize.xyz;";
        compilationDefines["APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center)"] =
            "{"
            "pos.x -= floor((pos.x-center.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x; \\\n"
            "pos.y -= floor((pos.y-center.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y; \\\n"
            "pos.z -= floor((pos.z-center.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;}";
    }

    // Create the work thread used for parallelization when running on multiple devices.

    thread = new WorkThread();

    // Create utilities objects.

    bonded = new OpenCLBondedUtilities(*this);
    nonbonded = new OpenCLNonbondedUtilities(*this);
    integration = new OpenCLIntegrationUtilities(*this, system);
    expression = new OpenCLExpressionUtilities(*this);
}

/**
 * Release everything this context owns: the registered force infos, reorder
 * listeners and pre/post computations, the pinned host buffer, the device arrays,
 * the utility objects and the work thread.
 */
OpenCLContext::~OpenCLContext() {
    // Objects registered with the context are owned by it.
    for (int i = 0; i < (int) forces.size(); i++)
        delete forces[i];
    for (int i = 0; i < (int) reorderListeners.size(); i++)
        delete reorderListeners[i];
    for (int i = 0; i < (int) preComputations.size(); i++)
        delete preComputations[i];
    for (int i = 0; i < (int) postComputations.size(); i++)
        delete postComputations[i];

    // Deleting a null pointer is a no-op, so no explicit NULL checks are needed.
    delete pinnedBuffer;
    delete posq;
    delete posqCorrection;
    delete velm;
    delete force;
    delete forceBuffers;
    delete longForceBuffer;
    delete energyBuffer;
    delete atomIndexDevice;
    delete integration;
    delete expression;
    delete bonded;
    delete nonbonded;
    delete thread;
}

452
void OpenCLContext::initialize() {
Peter Eastman's avatar
Peter Eastman committed
453
    bonded->initialize(system);
454
    numForceBuffers = platformData.contexts.size();
Peter Eastman's avatar
Peter Eastman committed
455
    numForceBuffers = std::max(numForceBuffers, bonded->getNumForceBuffers());
456
457
    for (int i = 0; i < (int) forces.size(); i++)
        numForceBuffers = std::max(numForceBuffers, forces[i]->getRequiredForceBuffers());
458
459
460
461
462
463
464
465
    if (useDoublePrecision) {
        forceBuffers = OpenCLArray::create<mm_double4>(*this, paddedNumAtoms*numForceBuffers, "forceBuffers");
        force = OpenCLArray::create<mm_double4>(*this, &forceBuffers->getDeviceBuffer(), paddedNumAtoms, "force");
        energyBuffer = OpenCLArray::create<cl_double>(*this, max(numThreadBlocks*ThreadBlockSize, nonbonded->getNumEnergyBuffers()), "energyBuffer");
    }
    else {
        forceBuffers = OpenCLArray::create<mm_float4>(*this, paddedNumAtoms*numForceBuffers, "forceBuffers");
        force = OpenCLArray::create<mm_float4>(*this, &forceBuffers->getDeviceBuffer(), paddedNumAtoms, "force");
466
        energyBuffer = OpenCLArray::create<cl_double>(*this, max(numThreadBlocks*ThreadBlockSize, nonbonded->getNumEnergyBuffers()), "energyBuffer");
467
    }
468
    if (supports64BitGlobalAtomics) {
469
        longForceBuffer = OpenCLArray::create<cl_long>(*this, 3*paddedNumAtoms, "longForceBuffer");
470
471
472
473
        reduceForcesKernel.setArg<cl::Buffer>(0, longForceBuffer->getDeviceBuffer());
        reduceForcesKernel.setArg<cl::Buffer>(1, forceBuffers->getDeviceBuffer());
        reduceForcesKernel.setArg<cl_int>(2, paddedNumAtoms);
        reduceForcesKernel.setArg<cl_int>(3, numForceBuffers);
474
        addAutoclearBuffer(*longForceBuffer);
475
    }
476
477
    addAutoclearBuffer(*forceBuffers);
    addAutoclearBuffer(*energyBuffer);
478
    int bufferBytes = max(velm->getSize()*velm->getElementSize(), energyBuffer->getSize()*energyBuffer->getElementSize());
479
    pinnedBuffer = new cl::Buffer(context, CL_MEM_ALLOC_HOST_PTR, bufferBytes);
480
    pinnedMemory = currentQueue.enqueueMapBuffer(*pinnedBuffer, CL_TRUE, CL_MAP_READ | CL_MAP_WRITE, 0, bufferBytes);
481
482
483
484
485
486
487
488
    for (int i = 0; i < numAtoms; i++) {
        double mass = system.getParticleMass(i);
        if (useDoublePrecision || useMixedPrecision)
            ((mm_double4*) pinnedMemory)[i] = mm_double4(0.0, 0.0, 0.0, mass == 0.0 ? 0.0 : 1.0/mass);
        else
            ((mm_float4*) pinnedMemory)[i] = mm_float4(0.0f, 0.0f, 0.0f, mass == 0.0 ? 0.0f : (cl_float) (1.0/mass));
    }
    velm->upload(pinnedMemory);
489
490
    atomIndexDevice = OpenCLArray::create<cl_int>(*this, paddedNumAtoms, "atomIndexDevice");
    atomIndex.resize(paddedNumAtoms);
491
    for (int i = 0; i < paddedNumAtoms; ++i)
492
493
        atomIndex[i] = i;
    atomIndexDevice->upload(atomIndex);
494
    findMoleculeGroups();
495
    nonbonded->initialize(system);
496
497
498
499
500
501
}

/**
 * Register a force info object with this context by appending it to the
 * context's list of forces.
 *
 * @param force    the force info object to record
 */
void OpenCLContext::addForce(OpenCLForceInfo* force) {
    forces.insert(forces.end(), force);
}

502
/**
 * Report whether a character can be part of an identifier: an ASCII letter,
 * a digit, or an underscore.
 */
static bool isSymbolCharacter(char c) {
    return (c == '_' || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'));
}

/**
 * Replace every occurrence of each key in a map with the corresponding value.
 * A key is only replaced where it appears as a complete symbol, i.e. where the
 * characters immediately before and after it are not identifier characters.
 * Replacements are applied one key at a time, in the map's iteration order, so
 * text inserted for one key is visible when later keys are processed.
 *
 * @param input         the text to process
 * @param replacements  maps each symbol to the text that should replace it.  Empty keys are ignored.
 * @return a copy of the input with all replacements applied
 */
string OpenCLContext::replaceStrings(const string& input, const std::map<std::string, std::string>& replacements) const {
    string result = input;
    for (map<string, string>::const_iterator iter = replacements.begin(); iter != replacements.end(); ++iter) {
        const string& symbol = iter->first;
        const string& replacement = iter->second;
        if (symbol.empty())
            continue; // An empty key matches at every position, so the search below would never finish.
        string::size_type index = result.find(symbol);
        while (index != string::npos) {
            string::size_type end = index+symbol.size();
            bool startsSymbol = (index == 0 || !isSymbolCharacter(result[index-1]));
            bool endsSymbol = (end == result.size() || !isSymbolCharacter(result[end]));
            if (startsSymbol && endsSymbol) {
                // We have found a complete symbol, not part of a longer symbol.

                result.replace(index, symbol.size(), replacement);
                index += replacement.size(); // Skip the inserted text so it is not scanned again for this key.
            }
            else
                index++;
            index = result.find(symbol, index);
        }
    }
    return result;
}

534
535
/**
 * Compile a program from source without any extra preprocessor definitions.
 * This forwards to the overload that takes a map of defines, passing an empty map.
 *
 * @param source             the source code of the program
 * @param optimizationFlags  compiler options, forwarded unchanged
 * @return the compiled program
 */
cl::Program OpenCLContext::createProgram(const string source, const char* optimizationFlags) {
    const map<string, string> noDefines;
    return createProgram(source, noDefines, optimizationFlags);
}

538
cl::Program OpenCLContext::createProgram(const string source, const map<string, string>& defines, const char* optimizationFlags) {
Peter Eastman's avatar
Peter Eastman committed
539
    string options = (optimizationFlags == NULL ? defaultOptimizationOptions : string(optimizationFlags));
540
541
542
543
544
545
546
547
548
549
550
    stringstream src;
    if (!options.empty())
        src << "// Compilation Options: " << options << endl << endl;
    for (map<string, string>::const_iterator iter = compilationDefines.begin(); iter != compilationDefines.end(); ++iter) {
        src << "#define " << iter->first;
        if (!iter->second.empty())
            src << " " << iter->second;
        src << endl;
    }
    if (!compilationDefines.empty())
        src << endl;
551
552
553
554
555
    if (supportsDoublePrecision)
        src << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n";
    if (useDoublePrecision) {
        src << "typedef double real;\n";
        src << "typedef double2 real2;\n";
556
        src << "typedef double3 real3;\n";
557
558
559
560
561
        src << "typedef double4 real4;\n";
    }
    else {
        src << "typedef float real;\n";
        src << "typedef float2 real2;\n";
562
        src << "typedef float3 real3;\n";
563
564
565
566
567
        src << "typedef float4 real4;\n";
    }
    if (useDoublePrecision || useMixedPrecision) {
        src << "typedef double mixed;\n";
        src << "typedef double2 mixed2;\n";
568
        src << "typedef double3 mixed3;\n";
569
570
571
572
573
        src << "typedef double4 mixed4;\n";
    }
    else {
        src << "typedef float mixed;\n";
        src << "typedef float2 mixed2;\n";
574
        src << "typedef float3 mixed3;\n";
575
576
        src << "typedef float4 mixed4;\n";
    }
577
578
579
580
581
582
583
584
585
586
587
588
589
    for (map<string, string>::const_iterator iter = defines.begin(); iter != defines.end(); ++iter) {
        src << "#define " << iter->first;
        if (!iter->second.empty())
            src << " " << iter->second;
        src << endl;
    }
    if (!defines.empty())
        src << endl;
    src << source << endl;
    // Get length before using c_str() to avoid length() call invalidating the c_str() value.
    string src_string = src.str();
    ::size_t src_length = src_string.length();
    cl::Program::Sources sources(1, make_pair(src_string.c_str(), src_length));
590
591
    cl::Program program(context, sources);
    try {
592
        program.build(vector<cl::Device>(1, device), options.c_str());
593
594
595
596
597
598
    } catch (cl::Error err) {
        throw OpenMMException("Error compiling kernel: "+program.getBuildInfo<CL_PROGRAM_BUILD_LOG>(device));
    }
    return program;
}

599
600
601
602
603
604
605
606
607
608
609
610
/**
 * Get the command queue this context is currently using.
 */
cl::CommandQueue& OpenCLContext::getQueue() {
    return this->currentQueue;
}

/**
 * Make the given command queue the one this context currently uses.
 *
 * @param queue    the queue to use from now on
 */
void OpenCLContext::setQueue(cl::CommandQueue& queue) {
    this->currentQueue = queue;
}

void OpenCLContext::restoreDefaultQueue() {
    currentQueue = defaultQueue;
}

611
/**
 * Convert a number to a string in scientific notation.  In double precision
 * mode 16 digits of precision are written; otherwise 8 digits are written and
 * an "f" suffix is appended.
 *
 * @param value    the number to convert
 * @return the text form of the number
 */
string OpenCLContext::doubleToString(double value) const {
    const int digits = (useDoublePrecision ? 16 : 8);
    stringstream text;
    text.precision(digits);
    text << scientific << value;
    if (useDoublePrecision)
        return text.str();
    return text.str()+"f";
}

620
/**
 * Convert an integer to its decimal text form.
 *
 * @param value    the number to convert
 * @return the text form of the number
 */
string OpenCLContext::intToString(int value) const {
    ostringstream text;
    text << value;
    return text.str();
}

626
627
628
629
/**
 * Enqueue a kernel on the current command queue.  The global work size is the
 * number of work units rounded up to a whole number of blocks, capped at
 * numThreadBlocks blocks.
 *
 * @param kernel       the kernel to execute
 * @param workUnits    the number of work units the kernel should process
 * @param blockSize    the work group size to use.  -1 selects ThreadBlockSize.
 * @throws OpenMMException if enqueueing fails.  The message names the kernel and the OpenCL error.
 */
void OpenCLContext::executeKernel(cl::Kernel& kernel, int workUnits, int blockSize) {
    if (blockSize == -1)
        blockSize = ThreadBlockSize;
    int numBlocks = std::min((workUnits+blockSize-1)/blockSize, numThreadBlocks);
    int size = numBlocks*blockSize;
    try {
        currentQueue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange(size), cl::NDRange(blockSize));
    }
    catch (const cl::Error& err) {
        stringstream str;
        str<<"Error invoking kernel "<<kernel.getInfo<CL_KERNEL_FUNCTION_NAME>()<<": "<<err.what()<<" ("<<err.err()<<")";
        throw OpenMMException(str.str());
    }
}

640
/**
 * Clear the whole device buffer of an array.  The size passed on is the
 * array's element count multiplied by its element size.
 *
 * @param array    the array whose device buffer should be cleared
 */
void OpenCLContext::clearBuffer(OpenCLArray& array) {
    int numBytes = array.getSize()*array.getElementSize();
    clearBuffer(array.getDeviceBuffer(), numBytes);
}

644
/**
 * Run the buffer clearing kernel on a memory object.  The kernel is given
 * size/4 as its element count, and the same number of work units is launched.
 *
 * @param memory    the memory object to clear
 * @param size      the size of the memory object; it is divided by 4 before use
 */
void OpenCLContext::clearBuffer(cl::Memory& memory, int size) {
    const int numWords = size/4;
    clearBufferKernel.setArg<cl::Memory>(0, memory);
    clearBufferKernel.setArg<cl_int>(1, numWords);
    executeKernel(clearBufferKernel, numWords, 128);
}

/**
 * Register the device buffer of an array with the autoclear list.  The size
 * passed on is the array's element count multiplied by its element size.
 *
 * @param array    the array whose device buffer should be registered
 */
void OpenCLContext::addAutoclearBuffer(OpenCLArray& array) {
    int numBytes = array.getSize()*array.getElementSize();
    addAutoclearBuffer(array.getDeviceBuffer(), numBytes);
}

655
656
/**
 * Register a memory object with the autoclear list.  A pointer to the memory
 * object is stored (so the object must outlive the list), together with size/4.
 *
 * @param memory    the memory object to register
 * @param size      the size of the memory object; size/4 is what gets recorded
 */
void OpenCLContext::addAutoclearBuffer(cl::Memory& memory, int size) {
    const int numWords = size/4;
    autoclearBuffers.push_back(&memory);
    autoclearBufferSizes.push_back(numWords);
}

void OpenCLContext::clearAutoclearBuffers() {
    int base = 0;
    int total = autoclearBufferSizes.size();
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
    while (total-base >= 6) {
        clearSixBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearSixBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearSixBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearSixBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearSixBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearSixBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearSixBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearSixBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
        clearSixBuffersKernel.setArg<cl::Memory>(8, *autoclearBuffers[base+4]);
        clearSixBuffersKernel.setArg<cl_int>(9, autoclearBufferSizes[base+4]);
        clearSixBuffersKernel.setArg<cl::Memory>(10, *autoclearBuffers[base+5]);
        clearSixBuffersKernel.setArg<cl_int>(11, autoclearBufferSizes[base+5]);
        executeKernel(clearSixBuffersKernel, max(max(max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), autoclearBufferSizes[base+4]), autoclearBufferSizes[base+5]), 128);
        base += 6;
    }
    if (total-base == 5) {
        clearFiveBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearFiveBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearFiveBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearFiveBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearFiveBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearFiveBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearFiveBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearFiveBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
        clearFiveBuffersKernel.setArg<cl::Memory>(8, *autoclearBuffers[base+4]);
        clearFiveBuffersKernel.setArg<cl_int>(9, autoclearBufferSizes[base+4]);
        executeKernel(clearFiveBuffersKernel, max(max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), autoclearBufferSizes[base+4]), 128);
    }
    else if (total-base == 4) {
693
694
695
696
697
698
699
700
        clearFourBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearFourBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearFourBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearFourBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearFourBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearFourBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearFourBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearFourBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
701
        executeKernel(clearFourBuffersKernel, max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), 128);
702
    }
703
    else if (total-base == 3) {
704
705
706
707
708
709
        clearThreeBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearThreeBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearThreeBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearThreeBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearThreeBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearThreeBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
710
        executeKernel(clearThreeBuffersKernel, max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), 128);
711
712
713
714
715
716
    }
    else if (total-base == 2) {
        clearTwoBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearTwoBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearTwoBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearTwoBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
717
        executeKernel(clearTwoBuffersKernel, max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), 128);
718
719
720
721
722
723
    }
    else if (total-base == 1) {
        clearBuffer(*autoclearBuffers[base], autoclearBufferSizes[base]);
    }
}

724
725
726
727
728
729
730
/**
 * Reduce the force buffers.  When 64 bit global atomics are supported this runs
 * the dedicated force reduction kernel over paddedNumAtoms work units; otherwise
 * it falls back to reduceBuffer() on the force buffers.
 */
void OpenCLContext::reduceForces() {
    if (!supports64BitGlobalAtomics) {
        reduceBuffer(*forceBuffers, numForceBuffers);
        return;
    }
    executeKernel(reduceForcesKernel, paddedNumAtoms, 128);
}

731
/**
 * Run the real4 reduction kernel on an array that is made up of several
 * equally sized buffers stored one after another.
 *
 * @param array         the array containing the buffers to reduce
 * @param numBuffers    the number of buffers the array is divided into
 */
void OpenCLContext::reduceBuffer(OpenCLArray& array, int numBuffers) {
    const int elementsPerBuffer = array.getSize()/numBuffers;
    reduceReal4Kernel.setArg<cl::Buffer>(0, array.getDeviceBuffer());
    reduceReal4Kernel.setArg<cl_int>(1, elementsPerBuffer);
    reduceReal4Kernel.setArg<cl_int>(2, numBuffers);
    executeKernel(reduceReal4Kernel, elementsPerBuffer, 128);
}
738

739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
/**
 * This class ensures that atom reordering doesn't break virtual sites.
 *
 * Every virtual site in the System becomes one particle group consisting of the
 * site itself followed by the particles it depends on.  Two groups are treated
 * as identical only if the sites have the same type and the same weights.
 */
class OpenCLContext::VirtualSiteInfo : public OpenCLForceInfo {
public:
    /**
     * Record the type, particles and weights of every virtual site in a System.
     *
     * @param system    the System to scan for virtual sites
     */
    VirtualSiteInfo(const System& system) : OpenCLForceInfo(0) {
        for (int i = 0; i < system.getNumParticles(); i++) {
            if (!system.isVirtualSite(i))
                continue;
            siteTypes.push_back(&typeid(system.getVirtualSite(i)));

            // The group lists the virtual site first, then the particles it is built from.

            vector<int> particles;
            particles.push_back(i);
            for (int j = 0; j < system.getVirtualSite(i).getNumParticles(); j++)
                particles.push_back(system.getVirtualSite(i).getParticle(j));
            siteParticles.push_back(particles);

            // Record the weights for the site types we know about.  Other types get an empty list.

            vector<double> weights;
            const TwoParticleAverageSite* twoParticle = dynamic_cast<const TwoParticleAverageSite*>(&system.getVirtualSite(i));
            const ThreeParticleAverageSite* threeParticle = dynamic_cast<const ThreeParticleAverageSite*>(&system.getVirtualSite(i));
            const OutOfPlaneSite* outOfPlane = dynamic_cast<const OutOfPlaneSite*>(&system.getVirtualSite(i));
            if (twoParticle != NULL) {
                // A two particle average.

                weights.push_back(twoParticle->getWeight(0));
                weights.push_back(twoParticle->getWeight(1));
            }
            else if (threeParticle != NULL) {
                // A three particle average.

                weights.push_back(threeParticle->getWeight(0));
                weights.push_back(threeParticle->getWeight(1));
                weights.push_back(threeParticle->getWeight(2));
            }
            else if (outOfPlane != NULL) {
                // An out of plane site.

                weights.push_back(outOfPlane->getWeight12());
                weights.push_back(outOfPlane->getWeight13());
                weights.push_back(outOfPlane->getWeightCross());
            }
            siteWeights.push_back(weights);
        }
    }
    /**
     * Get the number of particle groups, which is the number of virtual sites found.
     */
    int getNumParticleGroups() {
        return siteTypes.size();
    }
    /**
     * Get the particles in a group: the virtual site followed by the particles it depends on.
     */
    void getParticlesInGroup(int index, std::vector<int>& particles) {
        particles = siteParticles[index];
    }
    /**
     * Report whether two groups describe virtual sites of the same type with the same weights.
     */
    bool areGroupsIdentical(int group1, int group2) {
        // Compare the type_info objects themselves.  Their addresses are not guaranteed to be
        // unique for a given type, so comparing the pointers could report a false mismatch.

        if (*siteTypes[group1] != *siteTypes[group2])
            return false;
        int numParticles = siteWeights[group1].size();
        if ((int) siteWeights[group2].size() != numParticles)
            return false;
        for (int i = 0; i < numParticles; i++)
            if (siteWeights[group1][i] != siteWeights[group2][i])
                return false;
        return true;
    }
private:
    vector<const type_info*> siteTypes;
    vector<vector<int> > siteParticles;
    vector<vector<double> > siteWeights;
};


805
806
void OpenCLContext::findMoleculeGroups() {
    // The first time this is called, we need to identify all the molecules in the system.
807

808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
    if (moleculeGroups.size() == 0) {
        // Add a ForceInfo that makes sure reordering doesn't break virtual sites.

        addForce(new VirtualSiteInfo(system));

        // First make a list of every other atom to which each atom is connect by a constraint or force group.

        vector<vector<int> > atomBonds(system.getNumParticles());
        for (int i = 0; i < system.getNumConstraints(); i++) {
            int particle1, particle2;
            double distance;
            system.getConstraintParameters(i, particle1, particle2, distance);
            atomBonds[particle1].push_back(particle2);
            atomBonds[particle2].push_back(particle1);
        }
        for (int i = 0; i < (int) forces.size(); i++) {
            for (int j = 0; j < forces[i]->getNumParticleGroups(); j++) {
                vector<int> particles;
                forces[i]->getParticlesInGroup(j, particles);
                for (int k = 0; k < (int) particles.size(); k++)
                    for (int m = 0; m < (int) particles.size(); m++)
                        if (k != m)
                            atomBonds[particles[k]].push_back(particles[m]);
            }
832
833
        }

834
        // Now identify atoms by which molecule they belong to.
835

836
837
838
839
840
841
        vector<vector<int> > atomIndices = ContextImpl::findMolecules(numAtoms, atomBonds);
        int numMolecules = atomIndices.size();
        vector<int> atomMolecule(numAtoms);
        for (int i = 0; i < (int) atomIndices.size(); i++)
            for (int j = 0; j < (int) atomIndices[i].size(); j++)
                atomMolecule[atomIndices[i][j]] = i;
842

843
        // Construct a description of each molecule.
844

845
846
847
848
        molecules.resize(numMolecules);
        for (int i = 0; i < numMolecules; i++) {
            molecules[i].atoms = atomIndices[i];
            molecules[i].groups.resize(forces.size());
849
        }
850
851
852
853
854
855
856
857
858
859
860
861
862
        for (int i = 0; i < system.getNumConstraints(); i++) {
            int particle1, particle2;
            double distance;
            system.getConstraintParameters(i, particle1, particle2, distance);
            molecules[atomMolecule[particle1]].constraints.push_back(i);
        }
        for (int i = 0; i < (int) forces.size(); i++)
            for (int j = 0; j < forces[i]->getNumParticleGroups(); j++) {
                vector<int> particles;
                forces[i]->getParticlesInGroup(j, particles);
                molecules[atomMolecule[particles[0]]].groups[i].push_back(j);
            }
    }
863
864
865
866
867

    // Sort them into groups of identical molecules.

    vector<Molecule> uniqueMolecules;
    vector<vector<int> > moleculeInstances;
868
    vector<vector<int> > moleculeOffsets;
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
    for (int molIndex = 0; molIndex < (int) molecules.size(); molIndex++) {
        Molecule& mol = molecules[molIndex];

        // See if it is identical to another molecule.

        bool isNew = true;
        for (int j = 0; j < (int) uniqueMolecules.size() && isNew; j++) {
            Molecule& mol2 = uniqueMolecules[j];
            bool identical = (mol.atoms.size() == mol2.atoms.size() && mol.constraints.size() == mol2.constraints.size());

            // See if the atoms are identical.

            int atomOffset = mol2.atoms[0]-mol.atoms[0];
            for (int i = 0; i < (int) mol.atoms.size() && identical; i++) {
                if (mol.atoms[i] != mol2.atoms[i]-atomOffset || system.getParticleMass(mol.atoms[i]) != system.getParticleMass(mol2.atoms[i]))
                    identical = false;
885
                for (int k = 0; k < (int) forces.size(); k++)
886
887
888
                    if (!forces[k]->areParticlesIdentical(mol.atoms[i], mol2.atoms[i]))
                        identical = false;
            }
889

890
891
892
893
894
895
896
            // See if the constraints are identical.

            for (int i = 0; i < (int) mol.constraints.size() && identical; i++) {
                int c1particle1, c1particle2, c2particle1, c2particle2;
                double distance1, distance2;
                system.getConstraintParameters(mol.constraints[i], c1particle1, c1particle2, distance1);
                system.getConstraintParameters(mol2.constraints[i], c2particle1, c2particle2, distance2);
897
                if (c1particle1 != c2particle1-atomOffset || c1particle2 != c2particle2-atomOffset || distance1 != distance2)
898
899
900
901
902
                    identical = false;
            }

            // See if the force groups are identical.

903
            for (int i = 0; i < (int) forces.size() && identical; i++) {
904
905
                if (mol.groups[i].size() != mol2.groups[i].size())
                    identical = false;
906
                for (int k = 0; k < (int) mol.groups[i].size() && identical; k++)
907
908
909
910
                    if (!forces[i]->areGroupsIdentical(mol.groups[i][k], mol2.groups[i][k]))
                        identical = false;
            }
            if (identical) {
911
912
                moleculeInstances[j].push_back(molIndex);
                moleculeOffsets[j].push_back(mol.atoms[0]);
913
914
915
916
917
918
                isNew = false;
            }
        }
        if (isNew) {
            uniqueMolecules.push_back(mol);
            moleculeInstances.push_back(vector<int>());
919
920
921
            moleculeInstances[moleculeInstances.size()-1].push_back(molIndex);
            moleculeOffsets.push_back(vector<int>());
            moleculeOffsets[moleculeOffsets.size()-1].push_back(mol.atoms[0]);
922
923
924
925
926
927
        }
    }
    moleculeGroups.resize(moleculeInstances.size());
    for (int i = 0; i < (int) moleculeInstances.size(); i++)
    {
        moleculeGroups[i].instances = moleculeInstances[i];
928
        moleculeGroups[i].offsets = moleculeOffsets[i];
929
930
931
932
933
934
935
        vector<int>& atoms = uniqueMolecules[i].atoms;
        moleculeGroups[i].atoms.resize(atoms.size());
        for (int j = 0; j < (int) atoms.size(); j++)
            moleculeGroups[i].atoms[j] = atoms[j]-atoms[0];
    }
}

936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
void OpenCLContext::invalidateMolecules() {
    if (numAtoms == 0 || nonbonded == NULL || !nonbonded->getUseCutoff())
        return;
    bool valid = true;
    for (int group = 0; valid && group < (int) moleculeGroups.size(); group++) {
        MoleculeGroup& mol = moleculeGroups[group];
        vector<int>& instances = mol.instances;
        vector<int>& offsets = mol.offsets;
        vector<int>& atoms = mol.atoms;
        int numMolecules = instances.size();
        Molecule& m1 = molecules[instances[0]];
        int offset1 = offsets[0];
        for (int j = 1; valid && j < numMolecules; j++) {
            // See if the atoms are identical.

            Molecule& m2 = molecules[instances[j]];
            int offset2 = offsets[j];
            for (int i = 0; i < (int) atoms.size() && valid; i++) {
                for (int k = 0; k < (int) forces.size(); k++)
                    if (!forces[k]->areParticlesIdentical(atoms[i]+offset1, atoms[i]+offset2))
                        valid = false;
            }

            // See if the force groups are identical.

            for (int i = 0; i < (int) forces.size() && valid; i++) {
                for (int k = 0; k < (int) m1.groups[i].size() && valid; k++)
                    if (!forces[i]->areGroupsIdentical(m1.groups[i][k], m2.groups[i][k]))
                        valid = false;
            }
        }
    }
    if (valid)
        return;
970

971
972
973
    // The list of which molecules are identical is no longer valid.  We need to restore the
    // atoms to their original order, rebuild the list of identical molecules, and sort them
    // again.
974

975
    vector<mm_int4> newCellOffsets(numAtoms);
976
977
    if (useDoublePrecision) {
        vector<mm_double4> oldPosq(paddedNumAtoms);
978
        vector<mm_double4> newPosq(paddedNumAtoms, mm_double4(0,0,0,0));
979
        vector<mm_double4> oldVelm(paddedNumAtoms);
980
        vector<mm_double4> newVelm(paddedNumAtoms, mm_double4(0,0,0,0));
981
982
983
984
985
986
987
988
989
990
991
992
993
        posq->download(oldPosq);
        velm->download(oldVelm);
        for (int i = 0; i < numAtoms; i++) {
            int index = atomIndex[i];
            newPosq[index] = oldPosq[i];
            newVelm[index] = oldVelm[i];
            newCellOffsets[index] = posCellOffsets[i];
        }
        posq->upload(newPosq);
        velm->upload(newVelm);
    }
    else if (useMixedPrecision) {
        vector<mm_float4> oldPosq(paddedNumAtoms);
994
        vector<mm_float4> newPosq(paddedNumAtoms, mm_float4(0,0,0,0));
995
        vector<mm_float4> oldPosqCorrection(paddedNumAtoms);
996
        vector<mm_float4> newPosqCorrection(paddedNumAtoms, mm_float4(0,0,0,0));
997
        vector<mm_double4> oldVelm(paddedNumAtoms);
998
        vector<mm_double4> newVelm(paddedNumAtoms, mm_double4(0,0,0,0));
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
        posq->download(oldPosq);
        velm->download(oldVelm);
        for (int i = 0; i < numAtoms; i++) {
            int index = atomIndex[i];
            newPosq[index] = oldPosq[i];
            newPosqCorrection[index] = oldPosqCorrection[i];
            newVelm[index] = oldVelm[i];
            newCellOffsets[index] = posCellOffsets[i];
        }
        posq->upload(newPosq);
Peter Eastman's avatar
Peter Eastman committed
1009
        posqCorrection->upload(newPosqCorrection);
1010
1011
1012
1013
        velm->upload(newVelm);
    }
    else {
        vector<mm_float4> oldPosq(paddedNumAtoms);
1014
        vector<mm_float4> newPosq(paddedNumAtoms, mm_float4(0,0,0,0));
1015
        vector<mm_float4> oldVelm(paddedNumAtoms);
1016
        vector<mm_float4> newVelm(paddedNumAtoms, mm_float4(0,0,0,0));
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
        posq->download(oldPosq);
        velm->download(oldVelm);
        for (int i = 0; i < numAtoms; i++) {
            int index = atomIndex[i];
            newPosq[index] = oldPosq[i];
            newVelm[index] = oldVelm[i];
            newCellOffsets[index] = posCellOffsets[i];
        }
        posq->upload(newPosq);
        velm->upload(newVelm);
1027
1028
    }
    for (int i = 0; i < numAtoms; i++) {
1029
        atomIndex[i] = i;
1030
1031
        posCellOffsets[i] = newCellOffsets[i];
    }
1032
    atomIndexDevice->upload(atomIndex);
1033
    findMoleculeGroups();
1034
1035
    for (int i = 0; i < (int) reorderListeners.size(); i++)
        reorderListeners[i]->execute();
1036
    reorderAtoms();
1037
1038
}

1039
1040
void OpenCLContext::reorderAtoms() {
    atomsWereReordered = false;
Peter Eastman's avatar
Peter Eastman committed
1041
    if (numAtoms == 0 || nonbonded == NULL || !nonbonded->getUseCutoff() || stepsSinceReorder < 250) {
1042
        stepsSinceReorder++;
1043
        return;
1044
    }
Peter Eastman's avatar
Peter Eastman committed
1045
    atomsWereReordered = true;
1046
    stepsSinceReorder = 0;
1047
    if (useDoublePrecision)
1048
        reorderAtomsImpl<cl_double, mm_double4, cl_double, mm_double4>();
1049
    else if (useMixedPrecision)
1050
        reorderAtomsImpl<cl_float, mm_float4, cl_double, mm_double4>();
1051
    else
1052
        reorderAtomsImpl<cl_float, mm_float4, cl_float, mm_float4>();
1053
1054
1055
}

template <class Real, class Real4, class Mixed, class Mixed4>
1056
void OpenCLContext::reorderAtomsImpl() {
1057
1058
1059

    // Find the range of positions and the number of bins along each axis.

1060
1061
1062
    vector<Real4> oldPosq(paddedNumAtoms);
    vector<Real4> oldPosqCorrection(paddedNumAtoms);
    vector<Mixed4> oldVelm(paddedNumAtoms);
1063
1064
    posq->download(oldPosq);
    velm->download(oldVelm);
1065
1066
1067
1068
1069
    if (useMixedPrecision)
        posqCorrection->download(oldPosqCorrection);
    Real minx = oldPosq[0].x, maxx = oldPosq[0].x;
    Real miny = oldPosq[0].y, maxy = oldPosq[0].y;
    Real minz = oldPosq[0].z, maxz = oldPosq[0].z;
1070
1071
    if (nonbonded->getUsePeriodic()) {
        minx = miny = minz = 0.0;
1072
1073
1074
        maxx = periodicBoxSizeDouble.x;
        maxy = periodicBoxSizeDouble.y;
        maxz = periodicBoxSizeDouble.z;
1075
1076
1077
    }
    else {
        for (int i = 1; i < numAtoms; i++) {
1078
            const Real4& pos = oldPosq[i];
1079
1080
1081
1082
1083
1084
            minx = min(minx, pos.x);
            maxx = max(maxx, pos.x);
            miny = min(miny, pos.y);
            maxy = max(maxy, pos.y);
            minz = min(minz, pos.z);
            maxz = max(maxz, pos.z);
1085
1086
1087
1088
1089
1090
        }
    }

    // Loop over each group of identical molecules and reorder them.

    vector<int> originalIndex(numAtoms);
1091
1092
1093
    vector<Real4> newPosq(paddedNumAtoms, Real4(0,0,0,0));
    vector<Real4> newPosqCorrection(paddedNumAtoms, Real4(0,0,0,0));
    vector<Mixed4> newVelm(paddedNumAtoms, Mixed4(0,0,0,0));
1094
1095
1096
1097
1098
    vector<mm_int4> newCellOffsets(numAtoms);
    for (int group = 0; group < (int) moleculeGroups.size(); group++) {
        // Find the center of each molecule.

        MoleculeGroup& mol = moleculeGroups[group];
1099
        int numMolecules = mol.offsets.size();
1100
        vector<int>& atoms = mol.atoms;
1101
1102
        vector<Real4> molPos(numMolecules);
        Real invNumAtoms = (Real) (1.0/atoms.size());
1103
1104
1105
1106
1107
        for (int i = 0; i < numMolecules; i++) {
            molPos[i].x = 0.0f;
            molPos[i].y = 0.0f;
            molPos[i].z = 0.0f;
            for (int j = 0; j < (int)atoms.size(); j++) {
1108
                int atom = atoms[j]+mol.offsets[i];
1109
                const Real4& pos = oldPosq[atom];
1110
1111
1112
                molPos[i].x += pos.x;
                molPos[i].y += pos.y;
                molPos[i].z += pos.z;
1113
            }
1114
1115
1116
            molPos[i].x *= invNumAtoms;
            molPos[i].y *= invNumAtoms;
            molPos[i].z *= invNumAtoms;
1117
1118
            if (molPos[i].x != molPos[i].x)
                throw OpenMMException("Particle coordinate is nan");
1119
1120
1121
1122
1123
        }
        if (nonbonded->getUsePeriodic()) {
            // Move each molecule position into the same box.

            for (int i = 0; i < numMolecules; i++) {
Peter Eastman's avatar
Peter Eastman committed
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
                Real4 center = molPos[i];
                int zcell = (int) floor(center.z*invPeriodicBoxSize.z);
                center.x -= zcell*periodicBoxVecZ.x;
                center.y -= zcell*periodicBoxVecZ.y;
                center.z -= zcell*periodicBoxVecZ.z;
                int ycell = (int) floor(center.y*invPeriodicBoxSize.y);
                center.x -= ycell*periodicBoxVecY.x;
                center.y -= ycell*periodicBoxVecY.y;
                int xcell = (int) floor(center.x*invPeriodicBoxSize.x);
                center.x -= xcell*periodicBoxVecX.x;
                if (xcell != 0 || ycell != 0 || zcell != 0) {
                    Real dx = molPos[i].x-center.x;
                    Real dy = molPos[i].y-center.y;
                    Real dz = molPos[i].z-center.z;
                    molPos[i] = center;
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
                    for (int j = 0; j < (int) atoms.size(); j++) {
                        int atom = atoms[j]+mol.offsets[i];
                        Real4 p = oldPosq[atom];
                        p.x -= dx;
                        p.y -= dy;
                        p.z -= dz;
                        oldPosq[atom] = p;
                        posCellOffsets[atom].x -= xcell;
                        posCellOffsets[atom].y -= ycell;
                        posCellOffsets[atom].z -= zcell;
1149
1150
1151
1152
1153
1154
1155
1156
                    }
                }
            }
        }

        // Select a bin for each molecule, then sort them by bin.

        bool useHilbert = (numMolecules > 5000 || atoms.size() > 8); // For small systems, a simple zigzag curve works better than a Hilbert curve.
1157
        Real binWidth;
1158
        if (useHilbert)
1159
            binWidth = (Real) (max(max(maxx-minx, maxy-miny), maxz-minz)/255.0);
1160
        else
1161
            binWidth = (Real) (0.2*nonbonded->getMaxCutoffDistance());
1162
        Real invBinWidth = (Real) (1.0/binWidth);
1163
1164
        int xbins = 1 + (int) ((maxx-minx)*invBinWidth);
        int ybins = 1 + (int) ((maxy-miny)*invBinWidth);
1165
1166
1167
        vector<pair<int, int> > molBins(numMolecules);
        bitmask_t coords[3];
        for (int i = 0; i < numMolecules; i++) {
1168
1169
1170
            int x = (int) ((molPos[i].x-minx)*invBinWidth);
            int y = (int) ((molPos[i].y-miny)*invBinWidth);
            int z = (int) ((molPos[i].z-minz)*invBinWidth);
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
            int bin;
            if (useHilbert) {
                coords[0] = x;
                coords[1] = y;
                coords[2] = z;
                bin = (int) hilbert_c2i(3, 8, coords);
            }
            else {
                int yodd = y&1;
                int zodd = z&1;
                bin = z*xbins*ybins;
                bin += (zodd ? ybins-y : y)*xbins;
                bin += (yodd ? xbins-x : x);
            }
            molBins[i] = pair<int, int>(bin, i);
        }
        sort(molBins.begin(), molBins.end());

        // Reorder the atoms.

        for (int i = 0; i < numMolecules; i++) {
            for (int j = 0; j < (int)atoms.size(); j++) {
1193
1194
                int oldIndex = mol.offsets[molBins[i].second]+atoms[j];
                int newIndex = mol.offsets[i]+atoms[j];
1195
1196
                originalIndex[newIndex] = atomIndex[oldIndex];
                newPosq[newIndex] = oldPosq[oldIndex];
1197
1198
                if (useMixedPrecision)
                    newPosqCorrection[newIndex] = oldPosqCorrection[oldIndex];
1199
                newVelm[newIndex] = oldVelm[oldIndex];
1200
1201
1202
1203
1204
1205
1206
1207
                newCellOffsets[newIndex] = posCellOffsets[oldIndex];
            }
        }
    }

    // Update the streams.

    for (int i = 0; i < numAtoms; i++) {
1208
        atomIndex[i] = originalIndex[i];
1209
1210
        posCellOffsets[i] = newCellOffsets[i];
    }
1211
    posq->upload(newPosq);
1212
1213
    if (useMixedPrecision)
        posqCorrection->upload(newPosqCorrection);
1214
1215
    velm->upload(newVelm);
    atomIndexDevice->upload(atomIndex);
1216
1217
1218
1219
1220
1221
    for (int i = 0; i < (int) reorderListeners.size(); i++)
        reorderListeners[i]->execute();
}

/**
 * Register a listener to be notified when atoms are reordered.
 *
 * The pointer is appended to the end of reorderListeners; the reordering code
 * calls execute() on every registered listener, in registration order, once
 * the reordered data has been uploaded.
 *
 * @param listener  the listener to append
 */
void OpenCLContext::addReorderListener(ReorderListener* listener) {
    reorderListeners.insert(reorderListeners.end(), listener);
}
1223

1224
1225
1226
1227
1228
1229
1230
1231
/**
 * Register a pre-computation by appending it to the end of preComputations.
 *
 * @param computation  the computation to append
 */
void OpenCLContext::addPreComputation(ForcePreComputation* computation) {
    preComputations.insert(preComputations.end(), computation);
}

/**
 * Register a post-computation by appending it to the end of postComputations.
 *
 * @param computation  the computation to append
 */
void OpenCLContext::addPostComputation(ForcePostComputation* computation) {
    postComputations.insert(postComputations.end(), computation);
}

1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
/**
 * Bundle of references to the state a WorkThread shares with its worker
 * thread.  An instance is passed to the thread's entry function through the
 * void* argument of pthread_create().  Only references are stored, so the
 * referenced objects must outlive this struct.
 */
struct OpenCLContext::WorkThread::ThreadData {
    std::queue<OpenCLContext::WorkTask*>& tasks; // tasks that have been queued but not yet started
    bool& waiting;                               // set to true when the worker finds the queue empty
    bool& finished;                              // set to true to ask the worker to exit
    pthread_mutex_t& queueLock;                  // mutex held while the fields above are accessed
    pthread_cond_t& waitForTaskCondition;        // the worker waits on this for new work
    pthread_cond_t& queueEmptyCondition;         // signaled by the worker when it runs out of work

    ThreadData(std::queue<OpenCLContext::WorkTask*>& taskQueue, bool& waitingFlag, bool& finishedFlag,
            pthread_mutex_t& lock, pthread_cond_t& taskAvailable, pthread_cond_t& queueDrained) :
        tasks(taskQueue),
        waiting(waitingFlag),
        finished(finishedFlag),
        queueLock(lock),
        waitForTaskCondition(taskAvailable),
        queueEmptyCondition(queueDrained) {
    }
};

/**
 * Entry point of the worker thread created by WorkThread.
 *
 * Repeatedly removes the task at the front of the queue, executes it and
 * deletes it.  When the queue is empty the thread marks itself as waiting,
 * signals queueEmptyCondition and sleeps on waitForTaskCondition.  Once the
 * finished flag has been set and the queue has been drained, the thread
 * marks itself as waiting one last time, deletes the ThreadData it was given
 * and exits.
 *
 * The shared state (tasks, waiting, finished) is only ever touched while
 * queueLock is held.  The lock is released solely for the duration of a
 * task's execution, so other threads can keep queueing work meanwhile.
 *
 * @param args  pointer to a heap allocated ThreadData; this function takes
 *              ownership of it and deletes it before returning
 * @return always 0
 */
static void* threadBody(void* args) {
    OpenCLContext::WorkThread::ThreadData& data = *reinterpret_cast<OpenCLContext::WorkThread::ThreadData*>(args);
    pthread_mutex_lock(&data.queueLock);
    for (;;) {
        // Sleep until there is work to do or we have been asked to exit.
        while (data.tasks.empty() && !data.finished) {
            data.waiting = true;
            pthread_cond_signal(&data.queueEmptyCondition);
            pthread_cond_wait(&data.waitForTaskCondition, &data.queueLock);
        }
        if (data.tasks.empty())
            break; // finished has been set and every queued task has been run.
        data.waiting = false;
        OpenCLContext::WorkTask* task = data.tasks.front();
        data.tasks.pop();

        // Run the task without holding the lock so addTask() is never blocked by it.
        pthread_mutex_unlock(&data.queueLock);
        task->execute();
        delete task;
        pthread_mutex_lock(&data.queueLock);
    }

    // Publish the final state while still holding the lock, so that a thread
    // blocked in (or just entering) flush() cannot miss the notification.
    data.waiting = true;
    pthread_cond_signal(&data.queueEmptyCondition);
    pthread_mutex_unlock(&data.queueLock);
    delete &data;
    return 0;
}

/**
 * Create the synchronization primitives and start the worker thread.
 *
 * The thread starts out marked as waiting and not finished.  The ThreadData
 * allocated here is handed to threadBody(), which deletes it when the thread
 * exits.
 */
OpenCLContext::WorkThread::WorkThread() : waiting(true), finished(false) {
    // The primitives must exist before the thread is started, because the
    // thread uses them immediately.
    pthread_cond_init(&queueEmptyCondition, NULL);
    pthread_cond_init(&waitForTaskCondition, NULL);
    pthread_mutex_init(&queueLock, NULL);

    ThreadData* threadArgs = new ThreadData(tasks, waiting, finished, queueLock,
            waitForTaskCondition, queueEmptyCondition);
    pthread_create(&thread, NULL, threadBody, threadArgs);
}

/**
 * Ask the worker thread to exit, wait for it to do so, then release the
 * synchronization primitives.
 */
OpenCLContext::WorkThread::~WorkThread() {
    // Set the exit flag under the lock and wake the worker in case it is
    // sleeping while it waits for a task.
    pthread_mutex_lock(&queueLock);
    finished = true;
    pthread_cond_broadcast(&waitForTaskCondition);
    pthread_mutex_unlock(&queueLock);

    // The primitives may only be destroyed once the worker no longer uses them.
    pthread_join(thread, NULL);

    pthread_cond_destroy(&queueEmptyCondition);
    pthread_cond_destroy(&waitForTaskCondition);
    pthread_mutex_destroy(&queueLock);
}

/**
 * Append a task to the queue and wake the worker thread.
 *
 * The worker thread deletes the task after executing it, so the caller must
 * not delete it.
 *
 * @param task  the task to queue
 */
void OpenCLContext::WorkThread::addTask(OpenCLContext::WorkTask* task) {
    pthread_mutex_lock(&queueLock);

    tasks.push(task);

    // There is pending work now, so the thread no longer counts as waiting.
    waiting = false;

    pthread_cond_signal(&waitForTaskCondition);
    pthread_mutex_unlock(&queueLock);
}

bool OpenCLContext::WorkThread::isWaiting() {
    return waiting;
}

bool OpenCLContext::WorkThread::isFinished() {
    return finished;
}

/**
 * Block the calling thread until the worker thread is marked as waiting,
 * i.e. until it has found its task queue empty.
 */
void OpenCLContext::WorkThread::flush() {
    pthread_mutex_lock(&queueLock);
    for (;;) {
        if (waiting)
            break;
        // Releases queueLock while sleeping and re-acquires it before
        // returning; the flag is re-checked after every wake-up.
        pthread_cond_wait(&queueEmptyCondition, &queueLock);
    }
    pthread_mutex_unlock(&queueLock);
}