Unverified Commit eaa82a70 authored by Daniel Cámpora, committed by GitHub
Browse files

[Bugfix][DSV32] Fix overflow in topk. (#30754)


Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent f5f51e59
...@@ -550,8 +550,8 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill( ...@@ -550,8 +550,8 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
int rowEnd = rowEnds[rowIdx]; int rowEnd = rowEnds[rowIdx];
// Local pointers to this block // Local pointers to this block
outIndices += rowIdx * topK; outIndices += static_cast<int64_t>(rowIdx) * topK;
logits += rowIdx * stride0; logits += static_cast<int64_t>(rowIdx) * stride0;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>( topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK); nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
...@@ -576,19 +576,21 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode( ...@@ -576,19 +576,21 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
// Local pointers to this block // Local pointers to this block
if constexpr (!multipleBlocksPerRow && !mergeBlocks) { if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
outIndices += rowIdx * topK; outIndices += static_cast<int64_t>(rowIdx) * topK;
} else if constexpr (multipleBlocksPerRow) { } else if constexpr (multipleBlocksPerRow) {
const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192 const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192 rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize; rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK; outIndices +=
outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK; static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
outLogits +=
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
} else if constexpr (mergeBlocks) { } else if constexpr (mergeBlocks) {
rowEnd = numBlocksToMerge * topK; rowEnd = numBlocksToMerge * topK;
indices += rowIdx * numBlocksToMerge * topK; indices += static_cast<int64_t>(rowIdx) * numBlocksToMerge * topK;
outIndices += rowIdx * topK; outIndices += static_cast<int64_t>(rowIdx) * topK;
} }
logits += rowIdx * stride0; logits += static_cast<int64_t>(rowIdx) * stride0;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort, topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
multipleBlocksPerRow, mergeBlocks>( multipleBlocksPerRow, mergeBlocks>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment