Commit 35b9d787 authored by peastman's avatar peastman
Browse files

Workarounds for devices whose maximum workgroup size varies between kernels

parent d59a34fb
This diff is collapsed.
...@@ -317,42 +317,54 @@ void OpenCLNonbondedUtilities::initialize(const System& system) { ...@@ -317,42 +317,54 @@ void OpenCLNonbondedUtilities::initialize(const System& system) {
for (int i = 0; i < (int) exclusionBlocksForBlock.size(); i++) for (int i = 0; i < (int) exclusionBlocksForBlock.size(); i++)
maxExclusions = (maxExclusions > exclusionBlocksForBlock[i].size() ? maxExclusions : exclusionBlocksForBlock[i].size()); maxExclusions = (maxExclusions > exclusionBlocksForBlock[i].size() ? maxExclusions : exclusionBlocksForBlock[i].size());
defines["MAX_EXCLUSIONS"] = context.intToString(maxExclusions); defines["MAX_EXCLUSIONS"] = context.intToString(maxExclusions);
defines["GROUP_SIZE"] = (deviceIsCpu ? "32" : "128");
defines["BUFFER_GROUPS"] = (deviceIsCpu ? "4" : "2"); defines["BUFFER_GROUPS"] = (deviceIsCpu ? "4" : "2");
string file = (deviceIsCpu ? OpenCLKernelSources::findInteractingBlocks_cpu : OpenCLKernelSources::findInteractingBlocks); string file = (deviceIsCpu ? OpenCLKernelSources::findInteractingBlocks_cpu : OpenCLKernelSources::findInteractingBlocks);
cl::Program interactingBlocksProgram = context.createProgram(file, defines); int groupSize = (deviceIsCpu ? 32 : 128);
findBlockBoundsKernel = cl::Kernel(interactingBlocksProgram, "findBlockBounds"); while (true) {
findBlockBoundsKernel.setArg<cl_int>(0, context.getNumAtoms()); defines["GROUP_SIZE"] = context.intToString(groupSize);
findBlockBoundsKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer()); cl::Program interactingBlocksProgram = context.createProgram(file, defines);
findBlockBoundsKernel.setArg<cl::Buffer>(4, blockCenter->getDeviceBuffer()); findBlockBoundsKernel = cl::Kernel(interactingBlocksProgram, "findBlockBounds");
findBlockBoundsKernel.setArg<cl::Buffer>(5, blockBoundingBox->getDeviceBuffer()); findBlockBoundsKernel.setArg<cl_int>(0, context.getNumAtoms());
findBlockBoundsKernel.setArg<cl::Buffer>(6, rebuildNeighborList->getDeviceBuffer()); findBlockBoundsKernel.setArg<cl::Buffer>(3, context.getPosq().getDeviceBuffer());
findBlockBoundsKernel.setArg<cl::Buffer>(7, sortedBlocks->getDeviceBuffer()); findBlockBoundsKernel.setArg<cl::Buffer>(4, blockCenter->getDeviceBuffer());
sortBoxDataKernel = cl::Kernel(interactingBlocksProgram, "sortBoxData"); findBlockBoundsKernel.setArg<cl::Buffer>(5, blockBoundingBox->getDeviceBuffer());
sortBoxDataKernel.setArg<cl::Buffer>(0, sortedBlocks->getDeviceBuffer()); findBlockBoundsKernel.setArg<cl::Buffer>(6, rebuildNeighborList->getDeviceBuffer());
sortBoxDataKernel.setArg<cl::Buffer>(1, blockCenter->getDeviceBuffer()); findBlockBoundsKernel.setArg<cl::Buffer>(7, sortedBlocks->getDeviceBuffer());
sortBoxDataKernel.setArg<cl::Buffer>(2, blockBoundingBox->getDeviceBuffer()); sortBoxDataKernel = cl::Kernel(interactingBlocksProgram, "sortBoxData");
sortBoxDataKernel.setArg<cl::Buffer>(3, sortedBlockCenter->getDeviceBuffer()); sortBoxDataKernel.setArg<cl::Buffer>(0, sortedBlocks->getDeviceBuffer());
sortBoxDataKernel.setArg<cl::Buffer>(4, sortedBlockBoundingBox->getDeviceBuffer()); sortBoxDataKernel.setArg<cl::Buffer>(1, blockCenter->getDeviceBuffer());
sortBoxDataKernel.setArg<cl::Buffer>(5, context.getPosq().getDeviceBuffer()); sortBoxDataKernel.setArg<cl::Buffer>(2, blockBoundingBox->getDeviceBuffer());
sortBoxDataKernel.setArg<cl::Buffer>(6, oldPositions->getDeviceBuffer()); sortBoxDataKernel.setArg<cl::Buffer>(3, sortedBlockCenter->getDeviceBuffer());
sortBoxDataKernel.setArg<cl::Buffer>(7, interactionCount->getDeviceBuffer()); sortBoxDataKernel.setArg<cl::Buffer>(4, sortedBlockBoundingBox->getDeviceBuffer());
sortBoxDataKernel.setArg<cl::Buffer>(8, rebuildNeighborList->getDeviceBuffer()); sortBoxDataKernel.setArg<cl::Buffer>(5, context.getPosq().getDeviceBuffer());
findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions"); sortBoxDataKernel.setArg<cl::Buffer>(6, oldPositions->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(2, interactionCount->getDeviceBuffer()); sortBoxDataKernel.setArg<cl::Buffer>(7, interactionCount->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(3, interactingTiles->getDeviceBuffer()); sortBoxDataKernel.setArg<cl::Buffer>(8, rebuildNeighborList->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(4, interactingAtoms->getDeviceBuffer()); findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions");
findInteractingBlocksKernel.setArg<cl::Buffer>(5, context.getPosq().getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(2, interactionCount->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl_uint>(6, interactingTiles->getSize()); findInteractingBlocksKernel.setArg<cl::Buffer>(3, interactingTiles->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl_uint>(7, startBlockIndex); findInteractingBlocksKernel.setArg<cl::Buffer>(4, interactingAtoms->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl_uint>(8, numBlocks); findInteractingBlocksKernel.setArg<cl::Buffer>(5, context.getPosq().getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(9, sortedBlocks->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl_uint>(6, interactingTiles->getSize());
findInteractingBlocksKernel.setArg<cl::Buffer>(10, sortedBlockCenter->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl_uint>(7, startBlockIndex);
findInteractingBlocksKernel.setArg<cl::Buffer>(11, sortedBlockBoundingBox->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl_uint>(8, numBlocks);
findInteractingBlocksKernel.setArg<cl::Buffer>(12, exclusionIndices->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(9, sortedBlocks->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(13, exclusionRowIndices->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(10, sortedBlockCenter->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(14, oldPositions->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(11, sortedBlockBoundingBox->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(15, rebuildNeighborList->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(12, exclusionIndices->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(13, exclusionRowIndices->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(14, oldPositions->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(15, rebuildNeighborList->getDeviceBuffer());
if (findInteractingBlocksKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice()) < groupSize) {
// The device can't handle this block size, so reduce it.
groupSize -= 32;
if (groupSize < 32)
throw OpenMMException("Failed to create findInteractingBlocks kernel");
continue;
}
break;
}
} }
} }
......
...@@ -56,10 +56,13 @@ OpenCLSort::OpenCLSort(OpenCLContext& context, SortTrait* trait, unsigned int le ...@@ -56,10 +56,13 @@ OpenCLSort::OpenCLSort(OpenCLContext& context, SortTrait* trait, unsigned int le
unsigned int maxGroupSize = std::min(256, (int) context.getDevice().getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>()); unsigned int maxGroupSize = std::min(256, (int) context.getDevice().getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>());
int maxSharedMem = context.getDevice().getInfo<CL_DEVICE_LOCAL_MEM_SIZE>(); int maxSharedMem = context.getDevice().getInfo<CL_DEVICE_LOCAL_MEM_SIZE>();
unsigned int maxLocalBuffer = (unsigned int) ((maxSharedMem/trait->getDataSize())/2); unsigned int maxLocalBuffer = (unsigned int) ((maxSharedMem/trait->getDataSize())/2);
isShortList = (length <= maxLocalBuffer); unsigned int maxRangeSize = std::min(maxGroupSize, (unsigned int) computeRangeKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice()));
for (rangeKernelSize = 1; rangeKernelSize*2 <= maxGroupSize; rangeKernelSize *= 2) unsigned int maxPositionsSize = std::min(maxGroupSize, (unsigned int) computeBucketPositionsKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice()));
unsigned int maxShortListSize = shortListKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice());
isShortList = (length <= maxLocalBuffer && length < maxShortListSize);
for (rangeKernelSize = 1; rangeKernelSize*2 <= maxRangeSize; rangeKernelSize *= 2)
; ;
positionsKernelSize = rangeKernelSize; positionsKernelSize = std::min(rangeKernelSize, maxPositionsSize);
sortKernelSize = (isShortList ? rangeKernelSize : rangeKernelSize/2); sortKernelSize = (isShortList ? rangeKernelSize : rangeKernelSize/2);
if (rangeKernelSize > length) if (rangeKernelSize > length)
rangeKernelSize = length; rangeKernelSize = length;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment