OpenCLContext.cpp 44.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009 Stanford University and the Authors.           *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

27
28
29
30
#ifdef WIN32
  #define _USE_MATH_DEFINES // Needed to get M_PI
#endif
#include <cmath>
31
32
#include "OpenCLContext.h"
#include "OpenCLArray.h"
Peter Eastman's avatar
Peter Eastman committed
33
#include "OpenCLBondedUtilities.h"
34
#include "OpenCLForceInfo.h"
35
#include "OpenCLIntegrationUtilities.h"
36
#include "OpenCLKernelSources.h"
37
#include "OpenCLNonbondedUtilities.h"
38
#include "hilbert.h"
39
#include "openmm/Platform.h"
40
#include "openmm/System.h"
41
#include "openmm/VirtualSite.h"
Peter Eastman's avatar
Peter Eastman committed
42
#include <algorithm>
43
44
#include <fstream>
#include <iostream>
45
#include <sstream>
46
#include <typeinfo>
47
48

using namespace OpenMM;
49
using namespace std;
50

51
52
53
#ifndef CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV
  #define CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV 0x4000
#endif
54
55
56
#ifndef CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV
  #define CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV 0x4001
#endif
57

58
59
60
// Default number of work items per work group used by executeKernel() when no block size is given.
const int OpenCLContext::ThreadBlockSize(64);
// Atom counts are padded up to a multiple of this value (see paddedNumAtoms in the constructor).
const int OpenCLContext::TileSize(32);

61
/**
 * Error notification callback registered with the cl::Context.  Messages reported by the OpenCL
 * implementation are written to stderr, except for build-warning notifications (see below).
 *
 * @param errinfo       message text supplied by the OpenCL runtime (ignored if NULL)
 * @param private_info  implementation-specific data (unused)
 * @param cb            size of private_info (unused)
 * @param user_data     user pointer given at context creation (unused)
 */
static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_info, size_t cb, void* user_data) {
    if (errinfo == NULL)
        return;
    // Use std::string for the prefix test so this does not depend on <cstring>/strncmp being available.
    const string message(errinfo);
    const string skip = "OpenCL Build Warning : Compiler build log:";
    if (message.compare(0, skip.length(), skip) == 0)
        return; // OS X Lion insists on calling this for every build warning, even though they aren't errors.
    std::cerr << "OpenCL internal error: " << message << std::endl;
}

68
OpenCLContext::OpenCLContext(int numParticles, int platformIndex, int deviceIndex, OpenCLPlatform::PlatformData& platformData) :
Peter Eastman's avatar
Peter Eastman committed
69
70
        time(0.0), platformData(platformData), stepCount(0), computeForceCount(0), atomsWereReordered(false), posq(NULL),
        velm(NULL), forceBuffers(NULL), longForceBuffer(NULL), energyBuffer(NULL), atomIndex(NULL), integration(NULL),
Peter Eastman's avatar
Peter Eastman committed
71
        bonded(NULL), nonbonded(NULL), thread(NULL) {
72
    try {
73
        contextIndex = platformData.contexts.size();
74
75
        std::vector<cl::Platform> platforms;
        cl::Platform::get(&platforms);
76
77
        if (platformIndex < 0 || platformIndex >= platforms.size())
            throw OpenMMException("Illegal value for OpenCL platform index");
78
79
        vector<cl::Device> devices;
        platforms[platformIndex].getDevices(CL_DEVICE_TYPE_ALL, &devices);
80
        const int minThreadBlockSize = 32;
81
        if (deviceIndex < 0 || deviceIndex >= (int) devices.size()) {
82
            // Try to figure out which device is the fastest.
83

84
            int bestSpeed = -1;
85
            for (int i = 0; i < (int) devices.size(); i++) {
Peter Eastman's avatar
Peter Eastman committed
86
87
                if (platforms[platformIndex].getInfo<CL_PLATFORM_VENDOR>() == "Apple" && devices[i].getInfo<CL_DEVICE_VENDOR>() == "AMD")
                    continue; // Don't use AMD GPUs on OS X due to serious bugs.
88
                int maxSize = devices[i].getInfo<CL_DEVICE_MAX_WORK_ITEM_SIZES>()[0];
89
90
91
92
93
                int processingElementsPerComputeUnit = 8;
                if (devices[i].getInfo<CL_DEVICE_TYPE>() != CL_DEVICE_TYPE_GPU) {
                    processingElementsPerComputeUnit = 1;
                }
                else if (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_nv_device_attribute_query") != string::npos) {
94
95
96
97
                    cl_uint computeCapabilityMajor;
                    clGetDeviceInfo(devices[i](), CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV, sizeof(cl_uint), &computeCapabilityMajor, NULL);
                    processingElementsPerComputeUnit = (computeCapabilityMajor < 2 ? 8 : 32);
                }
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
                else if (devices[i].getInfo<CL_DEVICE_EXTENSIONS>().find("cl_amd_device_attribute_query") != string::npos) {
                    // This attribute does not ensure that all queries are supported by the runtime (it may be an older runtime,
                    // or the CPU device) so still have to check for errors.
                    try {
                        processingElementsPerComputeUnit =
                            // AMD GPUs either have a single VLIW SIMD or multiple scalar SIMDs.
                            // The SIMD width is the number of threads the SIMD executes per cycle.
                            // This will be less than the wavefront width since it takes several
                            // cycles to execute the full wavefront.
                            // The SIMD instruction width is the VLIW instruction width (or 1 for scalar),
                            // this is the number of ALUs that can be executing per instruction per thread. 
                            devices[i].getInfo<CL_DEVICE_SIMD_PER_COMPUTE_UNIT_AMD>() *
                            devices[i].getInfo<CL_DEVICE_SIMD_WIDTH_AMD>() *
                            devices[i].getInfo<CL_DEVICE_SIMD_INSTRUCTION_WIDTH_AMD>();
                        // Just in case any of the queries return 0.
                        if (processingElementsPerComputeUnit <= 0)
                            processingElementsPerComputeUnit = 1;
                    }
                    catch (cl::Error err) {
                        // Runtime does not support the queries so use default.
                    }
                }
120
121
                int speed = devices[i].getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>()*processingElementsPerComputeUnit*devices[i].getInfo<CL_DEVICE_MAX_CLOCK_FREQUENCY>();
                if (maxSize >= minThreadBlockSize && speed > bestSpeed) {
122
                    deviceIndex = i;
123
124
                    bestSpeed = speed;
                }
125
            }
126
        }
127
128
129
        if (deviceIndex == -1)
            throw OpenMMException("No compatible OpenCL device is available");
        device = devices[deviceIndex];
Peter Eastman's avatar
Peter Eastman committed
130
        this->deviceIndex = deviceIndex;
131
        if (device.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>() < minThreadBlockSize)
132
            throw OpenMMException("The specified OpenCL device is not compatible with OpenMM");
133
        compilationDefines["WORK_GROUP_SIZE"] = OpenCLExpressionUtilities::intToString(ThreadBlockSize);
134
        defaultOptimizationOptions = "-cl-fast-relaxed-math";
135
        supports64BitGlobalAtomics = (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_int64_base_atomics") != string::npos);
136
        supportsDoublePrecision = (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_khr_fp64") != string::npos);
137
        string vendor = device.getInfo<CL_DEVICE_VENDOR>();
138
        int numThreadBlocksPerComputeUnit = 6;
139
        if (vendor.size() >= 6 && vendor.substr(0, 6) == "NVIDIA") {
140
            compilationDefines["WARPS_ARE_ATOMIC"] = "";
141
            simdWidth = 32;
142
143
            if (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_nv_device_attribute_query") != string::npos) {
                // Compute level 1.2 and later Nvidia GPUs support 64 bit atomics, even though they don't list the
144
145
                // proper extension as supported.  We only use them on compute level 2.0 or later, since they're very
                // slow on earlier GPUs.
146

147
                cl_uint computeCapabilityMajor;
148
                clGetDeviceInfo(device(), CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV, sizeof(cl_uint), &computeCapabilityMajor, NULL);
149
                if (computeCapabilityMajor > 1)
150
151
                    supports64BitGlobalAtomics = true;
            }
152
        }
153
        else if (vendor.size() >= 28 && vendor.substr(0, 28) == "Advanced Micro Devices, Inc.") {
154
155
156
            // Disable 64 bit atomics.  A future version of the driver will support them, but until we can test that,
            // it's safest not to use them.
            supports64BitGlobalAtomics = false;
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
            if (device.getInfo<CL_DEVICE_TYPE>() != CL_DEVICE_TYPE_GPU) {
                /// \todo Is 6 a good value for the OpenCL CPU device?
                // numThreadBlocksPerComputeUnit = ?;
                simdWidth = 1;
            }
            else {
                bool amdPostSdk2_4 = false;
                // Default to 1 which will use the default kernels.
                simdWidth = 1;
                if (device.getInfo<CL_DEVICE_EXTENSIONS>().find("cl_amd_device_attribute_query") != string::npos) {
                    // This attribute does not ensure that all queries are supported by the runtime so still have to
                    // check for errors.
                    try {
                        // AMD has both 32 and 64 width SIMDs. Can determine by using:
                        // simdWidth = device.getInfo<CL_DEVICE_WAVEFRONT_WIDTH_AMD>();
                        // Must catch cl:Error as will fail if runtime does not support queries.
                        // However, the 32 width NVIDIA kernels do not have all the necessary
                        // barriers and so will not work for AMD.
                        // So for now leave default of 1 which will use the default kernels.

                        cl_uint simdPerComputeUnit = device.getInfo<CL_DEVICE_SIMD_PER_COMPUTE_UNIT_AMD>();
                        // If the GPU has multiple SIMDs per compute unit then it is uses the scalar instruction
                        // set instead of the VLIW instruction set. It therefore needs more thread blocks per
                        // compute unit to hide memory latency.
                        if (simdPerComputeUnit > 1)
                            numThreadBlocksPerComputeUnit = 4 * simdPerComputeUnit;

                        // If the queries are supported then must be newer than SDK 2.4.
                        amdPostSdk2_4 = true;
                    }
                    catch (cl::Error err) {
                        // Runtime does not support the query so is unlikely to be the newer scalar GPU.
                        // Stay with the default simdWidth and numThreadBlocksPerComputeUnit.
                    }
                }
                // AMD APP SDK 2.4 has a performance problem with atomics. Enable the work around. This is fixed after SDK 2.4.
                if (!amdPostSdk2_4)
                    compilationDefines["AMD_ATOMIC_WORK_AROUND"] = "";
            }
196
        }
197
198
        else
            simdWidth = 1;
Peter Eastman's avatar
Peter Eastman committed
199
        if (platforms[platformIndex].getInfo<CL_PLATFORM_VENDOR>() == "Apple" && vendor == "AMD")
200
            compilationDefines["MAC_AMD_WORKAROUND"] = "";
201
        if (supports64BitGlobalAtomics)
202
            compilationDefines["SUPPORTS_64_BIT_ATOMICS"] = "";
203
204
        if (supportsDoublePrecision)
            compilationDefines["SUPPORTS_DOUBLE_PRECISION"] = "";
205
206
207
208
        vector<cl::Device> contextDevices;
        contextDevices.push_back(device);
        cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[platformIndex](), 0};
        context = cl::Context(contextDevices, cprops, errorCallback);
209
210
211
212
        queue = cl::CommandQueue(context, device);
        numAtoms = numParticles;
        paddedNumAtoms = TileSize*((numParticles+TileSize-1)/TileSize);
        numAtomBlocks = (paddedNumAtoms+(TileSize-1))/TileSize;
213
        numThreadBlocks = numThreadBlocksPerComputeUnit*device.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
Peter Eastman's avatar
Peter Eastman committed
214
        bonded = new OpenCLBondedUtilities(*this);
215
216
217
        nonbonded = new OpenCLNonbondedUtilities(*this);
        posq = new OpenCLArray<mm_float4>(*this, paddedNumAtoms, "posq", true);
        velm = new OpenCLArray<mm_float4>(*this, paddedNumAtoms, "velm", true);
218
        posCellOffsets.resize(paddedNumAtoms, mm_int4(0, 0, 0, 0));
219
220
221
222
223
    }
    catch (cl::Error err) {
        std::stringstream str;
        str<<"Error initializing context: "<<err.what()<<" ("<<err.err()<<")";
        throw OpenMMException(str.str());
224
    }
225
226
227

    // Create utility kernels that are used in multiple places.

228
    utilities = createProgram(OpenCLKernelSources::utilities);
229
    clearBufferKernel = cl::Kernel(utilities, "clearBuffer");
230
231
232
    clearTwoBuffersKernel = cl::Kernel(utilities, "clearTwoBuffers");
    clearThreeBuffersKernel = cl::Kernel(utilities, "clearThreeBuffers");
    clearFourBuffersKernel = cl::Kernel(utilities, "clearFourBuffers");
233
234
    clearFiveBuffersKernel = cl::Kernel(utilities, "clearFiveBuffers");
    clearSixBuffersKernel = cl::Kernel(utilities, "clearSixBuffers");
235
    reduceFloat4Kernel = cl::Kernel(utilities, "reduceFloat4Buffer");
236
    reduceForcesKernel = cl::Kernel(utilities, "reduceForces");
237
238
239
240

    // Decide whether native_sqrt(), native_rsqrt(), and native_recip() are sufficiently accurate to use.

    cl::Kernel accuracyKernel(utilities, "determineNativeAccuracy");
241
    OpenCLArray<mm_float8> values(*this, 20, "values", true);
242
    float nextValue = 1e-4f;
243
    for (int i = 0; i < values.getSize(); ++i) {
244
        values[i].s0 = nextValue;
245
        nextValue *= (float) M_PI;
246
247
248
249
250
251
    }
    values.upload();
    accuracyKernel.setArg<cl::Buffer>(0, values.getDeviceBuffer());
    accuracyKernel.setArg<cl_int>(1, values.getSize());
    executeKernel(accuracyKernel, values.getSize());
    values.download();
252
    double maxSqrtError = 0.0, maxRsqrtError = 0.0, maxRecipError = 0.0, maxExpError = 0.0, maxLogError = 0.0;
253
    for (int i = 0; i < values.getSize(); ++i) {
254
255
256
257
258
259
260
        double v = values[i].s0;
        double correctSqrt = sqrt(v);
        maxSqrtError = max(maxSqrtError, fabs(correctSqrt-values[i].s1)/correctSqrt);
        maxRsqrtError = max(maxRsqrtError, fabs(1.0/correctSqrt-values[i].s2)*correctSqrt);
        maxRecipError = max(maxRecipError, fabs(1.0/v-values[i].s3)/values[i].s3);
        maxExpError = max(maxExpError, fabs(exp(v)-values[i].s4)/values[i].s4);
        maxLogError = max(maxLogError, fabs(log(v)-values[i].s5)/values[i].s5);
261
    }
262
263
264
265
266
    compilationDefines["SQRT"] = (maxSqrtError < 1e-6) ? "native_sqrt" : "sqrt";
    compilationDefines["RSQRT"] = (maxRsqrtError < 1e-6) ? "native_rsqrt" : "rsqrt";
    compilationDefines["RECIP"] = (maxRecipError < 1e-6) ? "native_recip" : "1.0f/";
    compilationDefines["EXP"] = (maxExpError < 1e-6) ? "native_exp" : "exp";
    compilationDefines["LOG"] = (maxLogError < 1e-6) ? "native_log" : "log";
267
268
269
270
    
    // Create the work thread used for parallelization when running on multiple devices.
    
    thread = new WorkThread();
271
272
273
}

/**
 * Destroys everything this context owns: the registered force infos and reorder listeners, the
 * device arrays, the integration/bonded/nonbonded utilities, and the work thread.
 */
OpenCLContext::~OpenCLContext() {
    for (size_t index = 0; index < forces.size(); ++index)
        delete forces[index];
    for (size_t index = 0; index < reorderListeners.size(); ++index)
        delete reorderListeners[index];

    // Deleting a null pointer is a no-op, so members that were never allocated need no explicit check.
    delete posq;
    delete velm;
    delete force;
    delete forceBuffers;
    delete longForceBuffer;
    delete energyBuffer;
    delete atomIndex;
    delete integration;
    delete bonded;
    delete nonbonded;
    delete thread;
}

302
void OpenCLContext::initialize(const System& system) {
303
304
305
306
    for (int i = 0; i < numAtoms; i++) {
        double mass = system.getParticleMass(i);
        (*velm)[i].w = (float) (mass == 0.0 ? 0.0 : 1.0/mass);
    }
307
    velm->upload();
Peter Eastman's avatar
Peter Eastman committed
308
    bonded->initialize(system);
309
    numForceBuffers = platformData.contexts.size();
Peter Eastman's avatar
Peter Eastman committed
310
    numForceBuffers = std::max(numForceBuffers, bonded->getNumForceBuffers());
311
312
313
    for (int i = 0; i < (int) forces.size(); i++)
        numForceBuffers = std::max(numForceBuffers, forces[i]->getRequiredForceBuffers());
    forceBuffers = new OpenCLArray<mm_float4>(*this, paddedNumAtoms*numForceBuffers, "forceBuffers", false);
314
315
316
317
318
319
320
321
    if (supports64BitGlobalAtomics) {
        longForceBuffer = new OpenCLArray<cl_long>(*this, 3*paddedNumAtoms, "longForceBuffer", false);
        reduceForcesKernel.setArg<cl::Buffer>(0, longForceBuffer->getDeviceBuffer());
        reduceForcesKernel.setArg<cl::Buffer>(1, forceBuffers->getDeviceBuffer());
        reduceForcesKernel.setArg<cl_int>(2, paddedNumAtoms);
        reduceForcesKernel.setArg<cl_int>(3, numForceBuffers);
        addAutoclearBuffer(longForceBuffer->getDeviceBuffer(), longForceBuffer->getSize()*2);
    }
322
    addAutoclearBuffer(forceBuffers->getDeviceBuffer(), forceBuffers->getSize()*4);
323
    force = new OpenCLArray<mm_float4>(*this, &forceBuffers->getDeviceBuffer(), paddedNumAtoms, "force", true);
324
    energyBuffer = new OpenCLArray<cl_float>(*this, max(numThreadBlocks*ThreadBlockSize, nonbonded->getNumEnergyBuffers()), "energyBuffer", true);
325
    addAutoclearBuffer(energyBuffer->getDeviceBuffer(), energyBuffer->getSize());
326
327
328
329
    atomIndex = new OpenCLArray<cl_int>(*this, paddedNumAtoms, "atomIndex", true);
    for (int i = 0; i < paddedNumAtoms; ++i)
        (*atomIndex)[i] = i;
    atomIndex->upload();
330
    findMoleculeGroups(system);
331
    integration = new OpenCLIntegrationUtilities(*this, system);
332
    nonbonded->initialize(system);
333
334
335
336
337
338
}

/**
 * Register a force with the context.  The context takes ownership: its destructor deletes every
 * registered force.
 */
void OpenCLContext::addForce(OpenCLForceInfo* force) {
    forces.insert(forces.end(), force);
}

339
340
341
342
343
344
345
346
347
348
349
350
/**
 * Load a kernel source file from the "opencl" subdirectory of the default plugins directory.
 * Each line of the file is returned terminated by '\n'.
 *
 * @param filename  name of the file relative to <plugins>/opencl/
 * @return the file contents
 * @throws OpenMMException if the file cannot be opened or a read error occurs
 */
string OpenCLContext::loadSourceFromFile(const string& filename) const {
    ifstream file((Platform::getDefaultPluginsDirectory()+"/opencl/"+filename).c_str());
    if (!file.is_open())
        throw OpenMMException("Unable to load kernel: "+filename);
    string kernel;
    string line;
    // Test the stream state returned by getline rather than eof().  With eof(), a read error that sets
    // badbit/failbit without eofbit loops forever, and an extra empty line is appended at end of file.
    while (getline(file, line)) {
        kernel += line;
        kernel += '\n';
    }
    if (file.bad())
        throw OpenMMException("Unable to load kernel: "+filename);
    return kernel; // The ifstream closes itself on scope exit.
}

/**
 * Load a kernel source file (see the single-argument overload), then apply every key -> value
 * substitution in replacements via replaceStrings().
 */
string OpenCLContext::loadSourceFromFile(const string& filename, const std::map<std::string, std::string>& replacements) const {
    const string rawSource = loadSourceFromFile(filename);
    return replaceStrings(rawSource, replacements);
}

/**
 * Return a copy of input in which every occurrence of each key in replacements is replaced by its
 * value.  Keys are applied one at a time in map order.
 *
 * Searching resumes after each inserted replacement rather than restarting at the beginning.  This
 * guarantees termination when a replacement contains its own key.  Empty keys are ignored, because
 * an empty pattern matches everywhere.
 */
string OpenCLContext::replaceStrings(const string& input, const std::map<std::string, std::string>& replacements) const {
    string result = input;
    for (map<string, string>::const_iterator iter = replacements.begin(); iter != replacements.end(); ++iter) {
        const string& target = iter->first;
        const string& replacement = iter->second;
        if (target.empty())
            continue;
        // Use string::size_type (not int) so the comparison with npos is exact.
        string::size_type index = result.find(target);
        while (index != string::npos) {
            result.replace(index, target.size(), replacement);
            index = result.find(target, index+replacement.size());
        }
    }
    return result;
}

371
372
/**
 * Compile a program using only the context-wide compilationDefines (no extra defines).
 * optimizationFlags follows the same convention as the three-argument overload.
 */
cl::Program OpenCLContext::createProgram(const string source, const char* optimizationFlags) {
    const map<string, string> noExtraDefines;
    return createProgram(source, noExtraDefines, optimizationFlags);
}

375
/**
 * Write one "#define NAME" or "#define NAME VALUE" line per entry to src.  A blank line follows
 * if at least one define was written.
 */
static void appendPreprocessorDefines(stringstream& src, const map<string, string>& defines) {
    for (map<string, string>::const_iterator iter = defines.begin(); iter != defines.end(); ++iter) {
        src << "#define " << iter->first;
        if (!iter->second.empty())
            src << " " << iter->second;
        src << endl;
    }
    if (!defines.empty())
        src << endl;
}

/**
 * Build an OpenCL program for this context's device.  The source is prefixed with a comment
 * recording the compiler options, then the context-wide compilationDefines, then the caller's defines.
 *
 * @param source             kernel source code
 * @param defines            extra preprocessor definitions for this program only
 * @param optimizationFlags  compiler options, or NULL to use defaultOptimizationOptions
 * @return the built program
 * @throws OpenMMException containing the build log if compilation fails
 */
cl::Program OpenCLContext::createProgram(const string source, const map<string, string>& defines, const char* optimizationFlags) {
    string options = (optimizationFlags == NULL ? defaultOptimizationOptions : optimizationFlags);
    stringstream src;
    if (!options.empty())
        src << "// Compilation Options: " << options << endl << endl;
    appendPreprocessorDefines(src, compilationDefines);
    appendPreprocessorDefines(src, defines);
    src << source << endl;
    // Sources holds a raw pointer into src_string, so the string must stay alive until the
    // cl::Program has been constructed from it.
    string src_string = src.str();
    ::size_t src_length = src_string.length();
    cl::Program::Sources sources(1, make_pair(src_string.c_str(), src_length));
    cl::Program program(context, sources);
    try {
        program.build(vector<cl::Device>(1, device), options.c_str());
    }
    catch (const cl::Error&) {
        throw OpenMMException("Error compiling kernel: "+program.getBuildInfo<CL_PROGRAM_BUILD_LOG>(device));
    }
    return program;
}

410
411
412
413
/**
 * Enqueue a kernel as a 1-D range.  The global size is the number of blocks needed to cover
 * workUnits, capped at numThreadBlocks, times blockSize.
 *
 * @param kernel     the kernel to launch (its arguments must already be set)
 * @param workUnits  number of work items needed; if <= 0 there is nothing to do and nothing is enqueued
 * @param blockSize  work-group size; -1 (or any non-positive value) selects ThreadBlockSize
 * @throws OpenMMException if enqueueing fails
 */
void OpenCLContext::executeKernel(cl::Kernel& kernel, int workUnits, int blockSize) {
    // Any non-positive size (not only the -1 sentinel) would otherwise divide by zero or go negative below.
    if (blockSize <= 0)
        blockSize = ThreadBlockSize;
    // A zero global work size is rejected by OpenCL 1.x (CL_INVALID_GLOBAL_WORK_SIZE), so skip empty launches.
    if (workUnits <= 0)
        return;
    int size = std::min((workUnits+blockSize-1)/blockSize, numThreadBlocks)*blockSize;
    try {
        queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange(size), cl::NDRange(blockSize));
    }
    catch (const cl::Error& err) {
        stringstream str;
        str<<"Error invoking kernel "<<kernel.getInfo<CL_KERNEL_FUNCTION_NAME>()<<": "<<err.what()<<" ("<<err.err()<<")";
        throw OpenMMException(str.str());
    }
}

424
/**
 * Clear every element of a float array on the device.
 */
void OpenCLContext::clearBuffer(OpenCLArray<float>& array) {
    const int floatCount = array.getSize();
    clearBuffer(array.getDeviceBuffer(), floatCount);
}

428
/**
 * Clear every element of a float4 array on the device.  The generic clear is sized in floats,
 * so each element counts as four.
 */
void OpenCLContext::clearBuffer(OpenCLArray<mm_float4>& array) {
    const int floatCount = 4*array.getSize();
    clearBuffer(array.getDeviceBuffer(), floatCount);
}

432
433
/**
 * Run the clearBuffer utility kernel on a device buffer, using work groups of 128.
 *
 * @param memory  buffer to clear
 * @param size    number of elements to clear (the float4 overload passes four per float4)
 */
void OpenCLContext::clearBuffer(cl::Memory& memory, int size) {
    cl::Kernel& kernel = clearBufferKernel;
    kernel.setArg<cl::Memory>(0, memory);
    kernel.setArg<cl_int>(1, size);
    executeKernel(kernel, size, 128);
}

438
439
440
441
442
443
444
445
/**
 * Register a buffer to be zeroed by clearAutoclearBuffers().  The buffer is stored by address, so it
 * must outlive the registration.  autoclearBuffers[i] and autoclearBufferSizes[i] describe the same buffer.
 */
void OpenCLContext::addAutoclearBuffer(cl::Memory& memory, int size) {
    cl::Memory* const buffer = &memory;
    autoclearBuffers.push_back(buffer);
    autoclearBufferSizes.push_back(size);
}

void OpenCLContext::clearAutoclearBuffers() {
    int base = 0;
    int total = autoclearBufferSizes.size();
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
    while (total-base >= 6) {
        clearSixBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearSixBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearSixBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearSixBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearSixBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearSixBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearSixBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearSixBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
        clearSixBuffersKernel.setArg<cl::Memory>(8, *autoclearBuffers[base+4]);
        clearSixBuffersKernel.setArg<cl_int>(9, autoclearBufferSizes[base+4]);
        clearSixBuffersKernel.setArg<cl::Memory>(10, *autoclearBuffers[base+5]);
        clearSixBuffersKernel.setArg<cl_int>(11, autoclearBufferSizes[base+5]);
        executeKernel(clearSixBuffersKernel, max(max(max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), autoclearBufferSizes[base+4]), autoclearBufferSizes[base+5]), 128);
        base += 6;
    }
    if (total-base == 5) {
        clearFiveBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearFiveBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearFiveBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearFiveBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearFiveBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearFiveBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearFiveBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearFiveBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
        clearFiveBuffersKernel.setArg<cl::Memory>(8, *autoclearBuffers[base+4]);
        clearFiveBuffersKernel.setArg<cl_int>(9, autoclearBufferSizes[base+4]);
        executeKernel(clearFiveBuffersKernel, max(max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), autoclearBufferSizes[base+4]), 128);
    }
    else if (total-base == 4) {
476
477
478
479
480
481
482
483
        clearFourBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearFourBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearFourBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearFourBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearFourBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearFourBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
        clearFourBuffersKernel.setArg<cl::Memory>(6, *autoclearBuffers[base+3]);
        clearFourBuffersKernel.setArg<cl_int>(7, autoclearBufferSizes[base+3]);
484
        executeKernel(clearFourBuffersKernel, max(max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), autoclearBufferSizes[base+3]), 128);
485
    }
486
    else if (total-base == 3) {
487
488
489
490
491
492
        clearThreeBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearThreeBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearThreeBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearThreeBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
        clearThreeBuffersKernel.setArg<cl::Memory>(4, *autoclearBuffers[base+2]);
        clearThreeBuffersKernel.setArg<cl_int>(5, autoclearBufferSizes[base+2]);
493
        executeKernel(clearThreeBuffersKernel, max(max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), autoclearBufferSizes[base+2]), 128);
494
495
496
497
498
499
    }
    else if (total-base == 2) {
        clearTwoBuffersKernel.setArg<cl::Memory>(0, *autoclearBuffers[base]);
        clearTwoBuffersKernel.setArg<cl_int>(1, autoclearBufferSizes[base]);
        clearTwoBuffersKernel.setArg<cl::Memory>(2, *autoclearBuffers[base+1]);
        clearTwoBuffersKernel.setArg<cl_int>(3, autoclearBufferSizes[base+1]);
500
        executeKernel(clearTwoBuffersKernel, max(autoclearBufferSizes[base], autoclearBufferSizes[base+1]), 128);
501
502
503
504
505
506
    }
    else if (total-base == 1) {
        clearBuffer(*autoclearBuffers[base], autoclearBufferSizes[base]);
    }
}

507
508
509
510
511
512
513
/**
 * Combine the per-buffer force contributions.  Without 64 bit atomics, the float4 force buffers are
 * reduced with reduceBuffer().  With them, reduceForcesKernel runs; initialize() binds it to
 * longForceBuffer and forceBuffers.
 */
void OpenCLContext::reduceForces() {
    if (!supports64BitGlobalAtomics) {
        reduceBuffer(*forceBuffers, numForceBuffers);
        return;
    }
    executeKernel(reduceForcesKernel, paddedNumAtoms, 128);
}

514
515
516
517
518
/**
 * Run the reduceFloat4Buffer kernel on an array made of numBuffers equal-sized sub-buffers.
 *
 * @param array       array whose size is numBuffers times the size of one sub-buffer
 * @param numBuffers  number of sub-buffers; must be positive
 * @throws OpenMMException if numBuffers is not positive (it would otherwise divide by zero)
 */
void OpenCLContext::reduceBuffer(OpenCLArray<mm_float4>& array, int numBuffers) {
    if (numBuffers <= 0)
        throw OpenMMException("reduceBuffer: the number of buffers must be positive");
    int bufferSize = array.getSize()/numBuffers;
    reduceFloat4Kernel.setArg<cl::Buffer>(0, array.getDeviceBuffer());
    reduceFloat4Kernel.setArg<cl_int>(1, bufferSize);
    reduceFloat4Kernel.setArg<cl_int>(2, numBuffers);
    executeKernel(reduceFloat4Kernel, bufferSize, 128);
}
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536

/**
 * Tag atom, and every atom reachable from it through atomBonds whose entry in atomMolecule is still
 * -1, with the given molecule index.
 *
 * An explicit worklist replaces recursion.  Recursion depth grew with the length of bonded chains,
 * which could overflow the stack for large polymers.  The set of atoms tagged is unchanged.
 */
void OpenCLContext::tagAtomsInMolecule(int atom, int molecule, vector<int>& atomMolecule, vector<vector<int> >& atomBonds) {
    atomMolecule[atom] = molecule;
    vector<int> pending(1, atom);
    while (!pending.empty()) {
        const int current = pending.back();
        pending.pop_back();
        const vector<int>& neighbors = atomBonds[current];
        for (size_t i = 0; i < neighbors.size(); i++) {
            const int next = neighbors[i];
            if (atomMolecule[next] == -1) {
                // Tag when pushing so each atom enters the worklist at most once.
                atomMolecule[next] = molecule;
                pending.push_back(next);
            }
        }
    }
}

/**
 * Per-molecule bookkeeping record.  Presumably filled in by findMoleculeGroups() for each set of
 * atoms tagged by tagAtomsInMolecule(); TODO confirm against that function.
 */
struct OpenCLContext::Molecule {
    vector<int> atoms;           // presumably indices of the atoms in this molecule — verify
    vector<int> constraints;     // presumably indices of constraints involving these atoms — verify
    vector<vector<int> > groups; // NOTE(review): meaning not visible in this chunk; likely per-force groups
};

537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
/**
 * ForceInfo that exposes each virtual site as a particle group.
 *
 * A group consists of the site followed by the particles it is defined from. Two groups count
 * as identical only when the sites have the same concrete type and the same weights, so atom
 * reordering never swaps molecules whose virtual sites differ.
 */
class OpenCLContext::VirtualSiteInfo : public OpenCLForceInfo {
public:
    VirtualSiteInfo(const System& system) : OpenCLForceInfo(0) {
        const int particleCount = system.getNumParticles();
        for (int particle = 0; particle < particleCount; particle++) {
            if (!system.isVirtualSite(particle))
                continue;

            // Group membership: the virtual site first, then every particle it depends on.

            vector<int> members(1, particle);
            const int numInputs = system.getVirtualSite(particle).getNumParticles();
            for (int input = 0; input < numInputs; input++)
                members.push_back(system.getVirtualSite(particle).getParticle(input));

            // Weights that must agree for two sites of the same type to be interchangeable.
            // Site types other than the three recognized here contribute no weights.

            vector<double> weights;
            if (const TwoParticleAverageSite* twoSite = dynamic_cast<const TwoParticleAverageSite*>(&system.getVirtualSite(particle))) {
                weights.push_back(twoSite->getWeight(0));
                weights.push_back(twoSite->getWeight(1));
            }
            else if (const ThreeParticleAverageSite* threeSite = dynamic_cast<const ThreeParticleAverageSite*>(&system.getVirtualSite(particle))) {
                weights.push_back(threeSite->getWeight(0));
                weights.push_back(threeSite->getWeight(1));
                weights.push_back(threeSite->getWeight(2));
            }
            else if (const OutOfPlaneSite* planeSite = dynamic_cast<const OutOfPlaneSite*>(&system.getVirtualSite(particle))) {
                weights.push_back(planeSite->getWeight12());
                weights.push_back(planeSite->getWeight13());
                weights.push_back(planeSite->getWeightCross());
            }

            siteTypes.push_back(&typeid(system.getVirtualSite(particle)));
            siteParticles.push_back(members);
            siteWeights.push_back(weights);
        }
    }
    int getNumParticleGroups() {
        return siteTypes.size();
    }
    void getParticlesInGroup(int index, std::vector<int>& particles) {
        const vector<int>& source = siteParticles[index];
        particles.assign(source.begin(), source.end());
    }
    bool areGroupsIdentical(int group1, int group2) {
        // Same dynamic type, same number of weights, and every weight equal.
        return siteTypes[group1] == siteTypes[group2] && siteWeights[group1] == siteWeights[group2];
    }
private:
    vector<const type_info*> siteTypes;      // dynamic type of each virtual site
    vector<vector<int> > siteParticles;      // site index followed by its defining particles
    vector<vector<double> > siteWeights;     // type-specific weights (may be empty)
};


603
void OpenCLContext::findMoleculeGroups(const System& system) {
604
605
606
607
    // Add a ForceInfo that makes sure reordering doesn't break virtual sites.
    
    addForce(new VirtualSiteInfo(system));
    
608
609
610
611
612
613
614
615
616
617
    // First make a list of every other atom to which each atom is connect by a constraint or force group.

    vector<vector<int> > atomBonds(system.getNumParticles());
    for (int i = 0; i < system.getNumConstraints(); i++) {
        int particle1, particle2;
        double distance;
        system.getConstraintParameters(i, particle1, particle2, distance);
        atomBonds[particle1].push_back(particle2);
        atomBonds[particle2].push_back(particle1);
    }
618
    for (int i = 0; i < (int) forces.size(); i++) {
619
620
621
        for (int j = 0; j < forces[i]->getNumParticleGroups(); j++) {
            vector<int> particles;
            forces[i]->getParticlesInGroup(j, particles);
622
623
            for (int k = 0; k < (int) particles.size(); k++)
                for (int m = 0; m < (int) particles.size(); m++)
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
                    if (k != m)
                        atomBonds[particles[k]].push_back(particles[m]);
        }
    }

    // Now tag atoms by which molecule they belong to.

    vector<int> atomMolecule(numAtoms, -1);
    int numMolecules = 0;
    for (int i = 0; i < numAtoms; i++)
        if (atomMolecule[i] == -1)
            tagAtomsInMolecule(i, numMolecules++, atomMolecule, atomBonds);
    vector<vector<int> > atomIndices(numMolecules);
    for (int i = 0; i < numAtoms; i++)
        atomIndices[atomMolecule[i]].push_back(i);

    // Construct a description of each molecule.

    vector<Molecule> molecules(numMolecules);
    for (int i = 0; i < numMolecules; i++) {
        molecules[i].atoms = atomIndices[i];
        molecules[i].groups.resize(forces.size());
    }
    for (int i = 0; i < system.getNumConstraints(); i++) {
        int particle1, particle2;
        double distance;
        system.getConstraintParameters(i, particle1, particle2, distance);
        molecules[atomMolecule[particle1]].constraints.push_back(i);
    }
653
    for (int i = 0; i < (int) forces.size(); i++)
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
        for (int j = 0; j < forces[i]->getNumParticleGroups(); j++) {
            vector<int> particles;
            forces[i]->getParticlesInGroup(j, particles);
            molecules[atomMolecule[particles[0]]].groups[i].push_back(j);
        }

    // Sort them into groups of identical molecules.

    vector<Molecule> uniqueMolecules;
    vector<vector<int> > moleculeInstances;
    for (int molIndex = 0; molIndex < (int) molecules.size(); molIndex++) {
        Molecule& mol = molecules[molIndex];

        // See if it is identical to another molecule.

        bool isNew = true;
        for (int j = 0; j < (int) uniqueMolecules.size() && isNew; j++) {
            Molecule& mol2 = uniqueMolecules[j];
            bool identical = (mol.atoms.size() == mol2.atoms.size() && mol.constraints.size() == mol2.constraints.size());

            // See if the atoms are identical.

            int atomOffset = mol2.atoms[0]-mol.atoms[0];
            for (int i = 0; i < (int) mol.atoms.size() && identical; i++) {
                if (mol.atoms[i] != mol2.atoms[i]-atomOffset || system.getParticleMass(mol.atoms[i]) != system.getParticleMass(mol2.atoms[i]))
                    identical = false;
680
                for (int k = 0; k < (int) forces.size(); k++)
681
682
683
684
685
686
687
688
689
690
691
                    if (!forces[k]->areParticlesIdentical(mol.atoms[i], mol2.atoms[i]))
                        identical = false;
            }
            
            // See if the constraints are identical.

            for (int i = 0; i < (int) mol.constraints.size() && identical; i++) {
                int c1particle1, c1particle2, c2particle1, c2particle2;
                double distance1, distance2;
                system.getConstraintParameters(mol.constraints[i], c1particle1, c1particle2, distance1);
                system.getConstraintParameters(mol2.constraints[i], c2particle1, c2particle2, distance2);
692
                if (c1particle1 != c2particle1-atomOffset || c1particle2 != c2particle2-atomOffset || distance1 != distance2)
693
694
695
696
697
                    identical = false;
            }

            // See if the force groups are identical.

698
            for (int i = 0; i < (int) forces.size() && identical; i++) {
699
700
                if (mol.groups[i].size() != mol2.groups[i].size())
                    identical = false;
701
                for (int k = 0; k < (int) mol.groups[i].size() && identical; k++)
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
                    if (!forces[i]->areGroupsIdentical(mol.groups[i][k], mol2.groups[i][k]))
                        identical = false;
            }
            if (identical) {
                moleculeInstances[j].push_back(mol.atoms[0]);
                isNew = false;
            }
        }
        if (isNew) {
            uniqueMolecules.push_back(mol);
            moleculeInstances.push_back(vector<int>());
            moleculeInstances[moleculeInstances.size()-1].push_back(mol.atoms[0]);
        }
    }
    moleculeGroups.resize(moleculeInstances.size());
    for (int i = 0; i < (int) moleculeInstances.size(); i++)
    {
        moleculeGroups[i].instances = moleculeInstances[i];
        vector<int>& atoms = uniqueMolecules[i].atoms;
        moleculeGroups[i].atoms.resize(atoms.size());
        for (int j = 0; j < (int) atoms.size(); j++)
            moleculeGroups[i].atoms[j] = atoms[j]-atoms[0];
    }
}

void OpenCLContext::reorderAtoms() {
    if (numAtoms == 0 || nonbonded == NULL || !nonbonded->getUseCutoff())
        return;
Peter Eastman's avatar
Peter Eastman committed
730
    atomsWereReordered = true;
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746

    // Find the range of positions and the number of bins along each axis.

    posq->download();
    velm->download();
    float minx = posq->get(0).x, maxx = posq->get(0).x;
    float miny = posq->get(0).y, maxy = posq->get(0).y;
    float minz = posq->get(0).z, maxz = posq->get(0).z;
    if (nonbonded->getUsePeriodic()) {
        minx = miny = minz = 0.0;
        maxx = periodicBoxSize.x;
        maxy = periodicBoxSize.y;
        maxz = periodicBoxSize.z;
    }
    else {
        for (int i = 1; i < numAtoms; i++) {
747
748
749
750
751
752
753
            const mm_float4& pos = posq->get(i);
            minx = min(minx, pos.x);
            maxx = max(maxx, pos.x);
            miny = min(miny, pos.y);
            maxy = max(maxy, pos.y);
            minz = min(minz, pos.z);
            maxz = max(maxz, pos.z);
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
        }
    }

    // Loop over each group of identical molecules and reorder them.

    vector<int> originalIndex(numAtoms);
    vector<mm_float4> newPosq(numAtoms);
    vector<mm_float4> newVelm(numAtoms);
    vector<mm_int4> newCellOffsets(numAtoms);
    for (int group = 0; group < (int) moleculeGroups.size(); group++) {
        // Find the center of each molecule.

        MoleculeGroup& mol = moleculeGroups[group];
        int numMolecules = mol.instances.size();
        vector<int>& atoms = mol.atoms;
        vector<mm_float4> molPos(numMolecules);
770
        float invNumAtoms = 1.0f/atoms.size();
771
772
773
774
775
776
        for (int i = 0; i < numMolecules; i++) {
            molPos[i].x = 0.0f;
            molPos[i].y = 0.0f;
            molPos[i].z = 0.0f;
            for (int j = 0; j < (int)atoms.size(); j++) {
                int atom = atoms[j]+mol.instances[i];
777
778
779
780
                const mm_float4& pos = posq->get(atom);
                molPos[i].x += pos.x;
                molPos[i].y += pos.y;
                molPos[i].z += pos.z;
781
            }
782
783
784
            molPos[i].x *= invNumAtoms;
            molPos[i].y *= invNumAtoms;
            molPos[i].z *= invNumAtoms;
785
786
787
788
789
        }
        if (nonbonded->getUsePeriodic()) {
            // Move each molecule position into the same box.

            for (int i = 0; i < numMolecules; i++) {
790
791
792
                int xcell = (int) floor(molPos[i].x*invPeriodicBoxSize.x);
                int ycell = (int) floor(molPos[i].y*invPeriodicBoxSize.y);
                int zcell = (int) floor(molPos[i].z*invPeriodicBoxSize.z);
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
                float dx = xcell*periodicBoxSize.x;
                float dy = ycell*periodicBoxSize.y;
                float dz = zcell*periodicBoxSize.z;
                if (dx != 0.0f || dy != 0.0f || dz != 0.0f) {
                    molPos[i].x -= dx;
                    molPos[i].y -= dy;
                    molPos[i].z -= dz;
                    for (int j = 0; j < (int) atoms.size(); j++) {
                        int atom = atoms[j]+mol.instances[i];
                        mm_float4 p = posq->get(atom);
                        p.x -= dx;
                        p.y -= dy;
                        p.z -= dz;
                        posq->set(atom, p);
                        posCellOffsets[atom].x -= xcell;
                        posCellOffsets[atom].y -= ycell;
                        posCellOffsets[atom].z -= zcell;
                    }
                }
            }
        }

        // Select a bin for each molecule, then sort them by bin.

        bool useHilbert = (numMolecules > 5000 || atoms.size() > 8); // For small systems, a simple zigzag curve works better than a Hilbert curve.
        float binWidth;
        if (useHilbert)
            binWidth = (float)(max(max(maxx-minx, maxy-miny), maxz-minz)/255.0);
        else
            binWidth = (float)(0.2*nonbonded->getCutoffDistance());
823
824
825
        float invBinWidth = 1.0f/binWidth;
        int xbins = 1 + (int) ((maxx-minx)*invBinWidth);
        int ybins = 1 + (int) ((maxy-miny)*invBinWidth);
826
827
828
        vector<pair<int, int> > molBins(numMolecules);
        bitmask_t coords[3];
        for (int i = 0; i < numMolecules; i++) {
829
830
831
            int x = (int) ((molPos[i].x-minx)*invBinWidth);
            int y = (int) ((molPos[i].y-miny)*invBinWidth);
            int z = (int) ((molPos[i].z-minz)*invBinWidth);
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
            int bin;
            if (useHilbert) {
                coords[0] = x;
                coords[1] = y;
                coords[2] = z;
                bin = (int) hilbert_c2i(3, 8, coords);
            }
            else {
                int yodd = y&1;
                int zodd = z&1;
                bin = z*xbins*ybins;
                bin += (zodd ? ybins-y : y)*xbins;
                bin += (yodd ? xbins-x : x);
            }
            molBins[i] = pair<int, int>(bin, i);
        }
        sort(molBins.begin(), molBins.end());

        // Reorder the atoms.

        for (int i = 0; i < numMolecules; i++) {
            for (int j = 0; j < (int)atoms.size(); j++) {
                int oldIndex = mol.instances[molBins[i].second]+atoms[j];
                int newIndex = mol.instances[i]+atoms[j];
                originalIndex[newIndex] = atomIndex->get(oldIndex);
                newPosq[newIndex] = posq->get(oldIndex);
                newVelm[newIndex] = velm->get(oldIndex);
                newCellOffsets[newIndex] = posCellOffsets[oldIndex];
            }
        }
    }

    // Update the streams.

    for (int i = 0; i < numAtoms; i++) {
        posq->set(i, newPosq[i]);
        velm->set(i, newVelm[i]);
        atomIndex->set(i, originalIndex[i]);
        posCellOffsets[i] = newCellOffsets[i];
    }
    posq->upload();
    velm->upload();
    atomIndex->upload();
875
876
877
878
879
880
    for (int i = 0; i < (int) reorderListeners.size(); i++)
        reorderListeners[i]->execute();
}

/**
 * Registers a listener whose execute() is called, in registration order, after reorderAtoms()
 * has permuted the atoms.
 */
void OpenCLContext::addReorderListener(ReorderListener* listener) {
    reorderListeners.insert(reorderListeners.end(), listener);
}
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964

/**
 * References to a WorkThread's shared state, handed to threadBody() as its thread argument.
 *
 * The struct owns none of the referenced objects. threadBody() deletes the struct when the
 * worker exits.
 */
struct OpenCLContext::WorkThread::ThreadData {
    ThreadData(std::queue<OpenCLContext::WorkTask*>& taskQueue, bool& waitingFlag, bool& finishedFlag,
               pthread_mutex_t& lock, pthread_cond_t& taskAvailable, pthread_cond_t& queueDrained)
        : tasks(taskQueue),
          waiting(waitingFlag),
          finished(finishedFlag),
          queueLock(lock),
          waitForTaskCondition(taskAvailable),
          queueEmptyCondition(queueDrained) {
    }
    std::queue<OpenCLContext::WorkTask*>& tasks;   // tasks waiting to be executed
    bool& waiting;                                  // set when the worker found the queue empty
    bool& finished;                                 // set by the owner to request shutdown
    pthread_mutex_t& queueLock;                     // protects the queue and the two flags
    pthread_cond_t& waitForTaskCondition;           // signalled when a task is added or on shutdown
    pthread_cond_t& queueEmptyCondition;            // signalled when the worker goes idle
};

/**
 * Worker loop of a WorkThread.
 *
 * Takes tasks off the shared queue, executes them and then deletes them. It keeps going until
 * `finished` is set and the queue has been drained.
 *
 * Whenever the queue is empty, it sets `waiting` and signals queueEmptyCondition, which is what
 * WorkThread::flush() blocks on.
 *
 * Every read and write of the queue and the flags happens with queueLock held. The lock is
 * dropped only while a task runs. (Previously the loop condition and the final
 * `waiting = true` plus signal were done unlocked. That is a data race, and it can lose the
 * wakeup flush() is waiting for.)
 *
 * @param args  heap-allocated ThreadData; deleted before returning
 */
static void* threadBody(void* args) {
    OpenCLContext::WorkThread::ThreadData& data = *static_cast<OpenCLContext::WorkThread::ThreadData*>(args);
    pthread_mutex_lock(&data.queueLock);
    while (true) {
        while (data.tasks.empty() && !data.finished) {
            data.waiting = true;
            pthread_cond_signal(&data.queueEmptyCondition);
            pthread_cond_wait(&data.waitForTaskCondition, &data.queueLock);
        }
        if (data.tasks.empty())
            break; // finished was requested and nothing is left to run
        data.waiting = false;
        OpenCLContext::WorkTask* task = data.tasks.front();
        data.tasks.pop();

        // Run the task without holding the lock so that new tasks can be queued meanwhile.

        pthread_mutex_unlock(&data.queueLock);
        task->execute();
        delete task;
        pthread_mutex_lock(&data.queueLock);
    }
    data.waiting = true;
    pthread_cond_signal(&data.queueEmptyCondition);
    pthread_mutex_unlock(&data.queueLock);
    delete &data;
    return 0;
}

/**
 * Initializes the queue's mutex and condition variables, then starts the worker thread.
 *
 * The worker receives a heap-allocated ThreadData that refers to this object's members.
 * threadBody() deletes it on exit.
 *
 * NOTE(review): the result of pthread_create is not checked.
 */
OpenCLContext::WorkThread::WorkThread() : waiting(true), finished(false) {
    pthread_cond_init(&queueEmptyCondition, NULL);
    pthread_cond_init(&waitForTaskCondition, NULL);
    pthread_mutex_init(&queueLock, NULL);
    pthread_create(&thread, NULL, threadBody,
            new ThreadData(tasks, waiting, finished, queueLock, waitForTaskCondition, queueEmptyCondition));
}

/**
 * Asks the worker to stop, waits for it to exit, then destroys the synchronization objects.
 *
 * The worker finishes any tasks still queued before it exits. The mutex and condition
 * variables are destroyed only after pthread_join, so the worker never sees them torn down.
 */
OpenCLContext::WorkThread::~WorkThread() {
    pthread_mutex_lock(&queueLock);
    finished = true;
    pthread_cond_broadcast(&waitForTaskCondition); // wake the worker if it is idle
    pthread_mutex_unlock(&queueLock);

    pthread_join(thread, NULL);

    pthread_cond_destroy(&queueEmptyCondition);
    pthread_cond_destroy(&waitForTaskCondition);
    pthread_mutex_destroy(&queueLock);
}

/**
 * Queues a task for the worker and wakes it.
 *
 * The worker deletes the task after executing it. `waiting` is cleared here, under the lock,
 * so that a flush() issued right after this call waits for the new task instead of returning
 * immediately.
 */
void OpenCLContext::WorkThread::addTask(OpenCLContext::WorkTask* task) {
    pthread_mutex_lock(&queueLock);
    waiting = false;
    tasks.push(task);
    pthread_cond_signal(&waitForTaskCondition);
    pthread_mutex_unlock(&queueLock);
}

bool OpenCLContext::WorkThread::isWaiting() {
    return waiting;
}

bool OpenCLContext::WorkThread::isFinished() {
    return finished;
}

/**
 * Blocks until `waiting` is set, i.e. until the worker has found the task queue empty after
 * running every previously queued task.
 */
void OpenCLContext::WorkThread::flush() {
    pthread_mutex_lock(&queueLock);
    for (;;) {
        if (waiting)
            break;
        pthread_cond_wait(&queueEmptyCondition, &queueLock);
    }
    pthread_mutex_unlock(&queueLock);
}