Commit 0297d2f0 authored by peastman's avatar peastman
Browse files

Fixed errors in sorting on some AMD GPUs

parent 7c7dc751
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2010-2018 Stanford University and the Authors. * * Portions copyright (c) 2010-2020 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -64,10 +64,11 @@ OpenCLSort::OpenCLSort(OpenCLContext& context, SortTrait* trait, unsigned int le ...@@ -64,10 +64,11 @@ OpenCLSort::OpenCLSort(OpenCLContext& context, SortTrait* trait, unsigned int le
unsigned int maxPositionsSize = std::min(maxGroupSize, (unsigned int) computeBucketPositionsKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice())); unsigned int maxPositionsSize = std::min(maxGroupSize, (unsigned int) computeBucketPositionsKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice()));
int maxLocalBuffer = (maxSharedMem/trait->getDataSize())/2; int maxLocalBuffer = (maxSharedMem/trait->getDataSize())/2;
unsigned int maxShortList = min(8192, max(maxLocalBuffer, (int) OpenCLContext::ThreadBlockSize*context.getNumThreadBlocks())); unsigned int maxShortList = min(8192, max(maxLocalBuffer, (int) OpenCLContext::ThreadBlockSize*context.getNumThreadBlocks()));
// On Qualcomm's OpenCL, it's essential to check against CL_KERNEL_WORK_GROUP_SIZE. Otherwise you get a crash. // The following line checks CL_KERNEL_WORK_GROUP_SIZE to make sure we don't create too large a workgroup.
// But AMD's OpenCL returns an inappropriately small value for it that is much shorter than the actual // Unfortunately, AMD's OpenCL returns an inappropriately small value for it that is much shorter than the actual
// maximum, so including the check hurts performance. For the moment I'm going to just comment it out. // maximum, so including the check hurts performance. For the moment I'm just leaving it commented out.
// If we officially support Qualcomm in the future, we'll need to do something better. // If the workgroup size turns out to be too large, we catch the exception and switch back to the standard
// sorting kernels.
//maxShortList = min(maxShortList, shortListKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice())); //maxShortList = min(maxShortList, shortListKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice()));
isShortList = (length <= maxShortList); isShortList = (length <= maxShortList);
string vendor = context.getDevice().getInfo<CL_DEVICE_VENDOR>(); string vendor = context.getDevice().getInfo<CL_DEVICE_VENDOR>();
...@@ -92,12 +93,10 @@ OpenCLSort::OpenCLSort(OpenCLContext& context, SortTrait* trait, unsigned int le ...@@ -92,12 +93,10 @@ OpenCLSort::OpenCLSort(OpenCLContext& context, SortTrait* trait, unsigned int le
// Create workspace arrays. // Create workspace arrays.
if (!isShortList) { dataRange.initialize(context, 2, trait->getKeySize(), "sortDataRange");
dataRange.initialize(context, 2, trait->getKeySize(), "sortDataRange"); bucketOffset.initialize<cl_uint>(context, numBuckets, "bucketOffset");
bucketOffset.initialize<cl_uint>(context, numBuckets, "bucketOffset"); bucketOfElement.initialize<cl_uint>(context, length, "bucketOfElement");
bucketOfElement.initialize<cl_uint>(context, length, "bucketOfElement"); offsetInBucket.initialize<cl_uint>(context, length, "offsetInBucket");
offsetInBucket.initialize<cl_uint>(context, length, "offsetInBucket");
}
buckets.initialize(context, length, trait->getDataSize(), "buckets"); buckets.initialize(context, length, trait->getDataSize(), "buckets");
} }
...@@ -113,68 +112,76 @@ void OpenCLSort::sort(OpenCLArray& data) { ...@@ -113,68 +112,76 @@ void OpenCLSort::sort(OpenCLArray& data) {
if (isShortList) { if (isShortList) {
// We can use a simpler sort kernel that does the entire operation in one kernel. // We can use a simpler sort kernel that does the entire operation in one kernel.
if (useShortList2) { try {
shortList2Kernel.setArg<cl::Buffer>(0, data.getDeviceBuffer()); if (useShortList2) {
shortList2Kernel.setArg<cl::Buffer>(1, buckets.getDeviceBuffer()); shortList2Kernel.setArg<cl::Buffer>(0, data.getDeviceBuffer());
shortList2Kernel.setArg<cl_int>(2, dataLength); shortList2Kernel.setArg<cl::Buffer>(1, buckets.getDeviceBuffer());
context.executeKernel(shortList2Kernel, dataLength); shortList2Kernel.setArg<cl_int>(2, dataLength);
buckets.copyTo(data); context.executeKernel(shortList2Kernel, dataLength);
buckets.copyTo(data);
}
else {
shortListKernel.setArg<cl::Buffer>(0, data.getDeviceBuffer());
shortListKernel.setArg<cl_uint>(1, dataLength);
shortListKernel.setArg(2, dataLength*trait->getDataSize(), NULL);
context.executeKernel(shortListKernel, sortKernelSize, sortKernelSize);
}
return;
} }
else { catch (exception& ex) {
shortListKernel.setArg<cl::Buffer>(0, data.getDeviceBuffer()); // This can happen if we chose too large a size for the kernel. Switch
shortListKernel.setArg<cl_uint>(1, dataLength); // over to the standard sorting method.
shortListKernel.setArg(2, dataLength*trait->getDataSize(), NULL);
context.executeKernel(shortListKernel, sortKernelSize, sortKernelSize); isShortList = false;
} }
} }
else {
// Compute the range of data values. // Compute the range of data values.
unsigned int numBuckets = bucketOffset.getSize(); unsigned int numBuckets = bucketOffset.getSize();
computeRangeKernel.setArg<cl::Buffer>(0, data.getDeviceBuffer()); computeRangeKernel.setArg<cl::Buffer>(0, data.getDeviceBuffer());
computeRangeKernel.setArg<cl_uint>(1, data.getSize()); computeRangeKernel.setArg<cl_uint>(1, data.getSize());
computeRangeKernel.setArg<cl::Buffer>(2, dataRange.getDeviceBuffer()); computeRangeKernel.setArg<cl::Buffer>(2, dataRange.getDeviceBuffer());
computeRangeKernel.setArg(3, rangeKernelSize*trait->getKeySize(), NULL); computeRangeKernel.setArg(3, rangeKernelSize*trait->getKeySize(), NULL);
computeRangeKernel.setArg(4, rangeKernelSize*trait->getKeySize(), NULL); computeRangeKernel.setArg(4, rangeKernelSize*trait->getKeySize(), NULL);
computeRangeKernel.setArg<cl_int>(5, numBuckets); computeRangeKernel.setArg<cl_int>(5, numBuckets);
computeRangeKernel.setArg<cl::Buffer>(6, bucketOffset.getDeviceBuffer()); computeRangeKernel.setArg<cl::Buffer>(6, bucketOffset.getDeviceBuffer());
context.executeKernel(computeRangeKernel, rangeKernelSize, rangeKernelSize); context.executeKernel(computeRangeKernel, rangeKernelSize, rangeKernelSize);
// Assign array elements to buckets. // Assign array elements to buckets.
assignElementsKernel.setArg<cl::Buffer>(0, data.getDeviceBuffer()); assignElementsKernel.setArg<cl::Buffer>(0, data.getDeviceBuffer());
assignElementsKernel.setArg<cl_int>(1, data.getSize()); assignElementsKernel.setArg<cl_int>(1, data.getSize());
assignElementsKernel.setArg<cl_int>(2, numBuckets); assignElementsKernel.setArg<cl_int>(2, numBuckets);
assignElementsKernel.setArg<cl::Buffer>(3, dataRange.getDeviceBuffer()); assignElementsKernel.setArg<cl::Buffer>(3, dataRange.getDeviceBuffer());
assignElementsKernel.setArg<cl::Buffer>(4, bucketOffset.getDeviceBuffer()); assignElementsKernel.setArg<cl::Buffer>(4, bucketOffset.getDeviceBuffer());
assignElementsKernel.setArg<cl::Buffer>(5, bucketOfElement.getDeviceBuffer()); assignElementsKernel.setArg<cl::Buffer>(5, bucketOfElement.getDeviceBuffer());
assignElementsKernel.setArg<cl::Buffer>(6, offsetInBucket.getDeviceBuffer()); assignElementsKernel.setArg<cl::Buffer>(6, offsetInBucket.getDeviceBuffer());
context.executeKernel(assignElementsKernel, data.getSize()); context.executeKernel(assignElementsKernel, data.getSize());
// Compute the position of each bucket. // Compute the position of each bucket.
computeBucketPositionsKernel.setArg<cl_int>(0, numBuckets); computeBucketPositionsKernel.setArg<cl_int>(0, numBuckets);
computeBucketPositionsKernel.setArg<cl::Buffer>(1, bucketOffset.getDeviceBuffer()); computeBucketPositionsKernel.setArg<cl::Buffer>(1, bucketOffset.getDeviceBuffer());
computeBucketPositionsKernel.setArg(2, positionsKernelSize*sizeof(cl_int), NULL); computeBucketPositionsKernel.setArg(2, positionsKernelSize*sizeof(cl_int), NULL);
context.executeKernel(computeBucketPositionsKernel, positionsKernelSize, positionsKernelSize); context.executeKernel(computeBucketPositionsKernel, positionsKernelSize, positionsKernelSize);
// Copy the data into the buckets. // Copy the data into the buckets.
copyToBucketsKernel.setArg<cl::Buffer>(0, data.getDeviceBuffer()); copyToBucketsKernel.setArg<cl::Buffer>(0, data.getDeviceBuffer());
copyToBucketsKernel.setArg<cl::Buffer>(1, buckets.getDeviceBuffer()); copyToBucketsKernel.setArg<cl::Buffer>(1, buckets.getDeviceBuffer());
copyToBucketsKernel.setArg<cl_int>(2, data.getSize()); copyToBucketsKernel.setArg<cl_int>(2, data.getSize());
copyToBucketsKernel.setArg<cl::Buffer>(3, bucketOffset.getDeviceBuffer()); copyToBucketsKernel.setArg<cl::Buffer>(3, bucketOffset.getDeviceBuffer());
copyToBucketsKernel.setArg<cl::Buffer>(4, bucketOfElement.getDeviceBuffer()); copyToBucketsKernel.setArg<cl::Buffer>(4, bucketOfElement.getDeviceBuffer());
copyToBucketsKernel.setArg<cl::Buffer>(5, offsetInBucket.getDeviceBuffer()); copyToBucketsKernel.setArg<cl::Buffer>(5, offsetInBucket.getDeviceBuffer());
context.executeKernel(copyToBucketsKernel, data.getSize()); context.executeKernel(copyToBucketsKernel, data.getSize());
// Sort each bucket. // Sort each bucket.
sortBucketsKernel.setArg<cl::Buffer>(0, data.getDeviceBuffer()); sortBucketsKernel.setArg<cl::Buffer>(0, data.getDeviceBuffer());
sortBucketsKernel.setArg<cl::Buffer>(1, buckets.getDeviceBuffer()); sortBucketsKernel.setArg<cl::Buffer>(1, buckets.getDeviceBuffer());
sortBucketsKernel.setArg<cl_int>(2, numBuckets); sortBucketsKernel.setArg<cl_int>(2, numBuckets);
sortBucketsKernel.setArg<cl::Buffer>(3, bucketOffset.getDeviceBuffer()); sortBucketsKernel.setArg<cl::Buffer>(3, bucketOffset.getDeviceBuffer());
sortBucketsKernel.setArg(4, sortKernelSize*trait->getDataSize(), NULL); sortBucketsKernel.setArg(4, sortKernelSize*trait->getDataSize(), NULL);
context.executeKernel(sortBucketsKernel, ((data.getSize()+sortKernelSize-1)/sortKernelSize)*sortKernelSize, sortKernelSize); context.executeKernel(sortBucketsKernel, ((data.getSize()+sortKernelSize-1)/sortKernelSize)*sortKernelSize, sortKernelSize);
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment