Commit 4c0621ed authored by Peter Eastman's avatar Peter Eastman
Browse files

Minor optimization to sorting

parent ad5821ff
......@@ -112,13 +112,12 @@ void CudaSort::sort(CudaArray& data) {
else {
// Compute the range of data values.
void* rangeArgs[] = {&data.getDevicePointer(), &dataLength, &dataRange->getDevicePointer()};
unsigned int numBuckets = bucketOffset->getSize();
void* rangeArgs[] = {&data.getDevicePointer(), &dataLength, &dataRange->getDevicePointer(), &numBuckets, &bucketOffset->getDevicePointer()};
context.executeKernel(computeRangeKernel, rangeArgs, rangeKernelSize, rangeKernelSize, rangeKernelSize*trait->getKeySize());
// Assign array elements to buckets.
unsigned int numBuckets = bucketOffset->getSize();
context.clearBuffer(*bucketOffset);
void* elementsArgs[] = {&data.getDevicePointer(), &dataLength, &numBuckets, &dataRange->getDevicePointer(),
&bucketOffset->getDevicePointer(), &bucketOfElement->getDevicePointer(), &offsetInBucket->getDevicePointer()};
context.executeKernel(assignElementsKernel, elementsArgs, data.getSize());
......
......@@ -50,7 +50,8 @@ __global__ void sortShortList(DATA_TYPE* __restrict__ data, unsigned int length)
* Calculate the minimum and maximum value in the array to be sorted. This kernel
* is executed as a single work group.
*/
__global__ void computeRange(const DATA_TYPE* __restrict__ data, unsigned int length, KEY_TYPE* __restrict__ range) {
__global__ void computeRange(const DATA_TYPE* __restrict__ data, unsigned int length, KEY_TYPE* __restrict__ range,
unsigned int numBuckets, unsigned int* __restrict__ bucketOffset) {
extern __shared__ KEY_TYPE rangeBuffer[];
KEY_TYPE minimum = MAX_KEY;
KEY_TYPE maximum = MIN_KEY;
......@@ -86,6 +87,11 @@ __global__ void computeRange(const DATA_TYPE* __restrict__ data, unsigned int le
range[0] = minimum;
range[1] = maximum;
}
// Clear the bucket counters in preparation for the next kernel.
for (unsigned int index = threadIdx.x; index < numBuckets; index += blockDim.x)
bucketOffset[index] = 0;
}
/**
......
......@@ -116,16 +116,17 @@ void OpenCLSort::sort(OpenCLArray& data) {
else {
// Compute the range of data values.
unsigned int numBuckets = bucketOffset->getSize();
computeRangeKernel.setArg<cl::Buffer>(0, data.getDeviceBuffer());
computeRangeKernel.setArg<cl_uint>(1, data.getSize());
computeRangeKernel.setArg<cl::Buffer>(2, dataRange->getDeviceBuffer());
computeRangeKernel.setArg(3, rangeKernelSize*trait->getKeySize(), NULL);
computeRangeKernel.setArg<cl_int>(4, numBuckets);
computeRangeKernel.setArg<cl::Buffer>(5, bucketOffset->getDeviceBuffer());
context.executeKernel(computeRangeKernel, rangeKernelSize, rangeKernelSize);
// Assign array elements to buckets.
unsigned int numBuckets = bucketOffset->getSize();
context.clearBuffer(*bucketOffset);
assignElementsKernel.setArg<cl::Buffer>(0, data.getDeviceBuffer());
assignElementsKernel.setArg<cl_int>(1, data.getSize());
assignElementsKernel.setArg<cl_int>(2, numBuckets);
......
......@@ -49,7 +49,8 @@ __kernel void sortShortList(__global DATA_TYPE* __restrict__ data, uint length,
* Calculate the minimum and maximum value in the array to be sorted. This kernel
* is executed as a single work group.
*/
__kernel void computeRange(__global const DATA_TYPE* restrict data, uint length, __global KEY_TYPE* restrict range, __local KEY_TYPE* restrict buffer) {
__kernel void computeRange(__global const DATA_TYPE* restrict data, uint length, __global KEY_TYPE* restrict range, __local KEY_TYPE* restrict buffer,
uint numBuckets, __global uint* restrict bucketOffset) {
KEY_TYPE minimum = MAX_KEY;
KEY_TYPE maximum = MIN_KEY;
......@@ -84,6 +85,11 @@ __kernel void computeRange(__global const DATA_TYPE* restrict data, uint length,
range[0] = minimum;
range[1] = maximum;
}
// Clear the bucket counters in preparation for the next kernel.
for (uint index = get_local_id(0); index < numBuckets; index += get_local_size(0))
bucketOffset[index] = 0;
}
/**
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment