Unverified Commit 4177f729 authored by nv-dlasalle's avatar nv-dlasalle Committed by GitHub
Browse files

Fix launch parameters index select kernel in sparse push (#3524)

parent cb39eae1
...@@ -445,13 +445,13 @@ NDArray SparsePull( ...@@ -445,13 +445,13 @@ NDArray SparsePull(
// and then index select them into place // and then index select them into place
Workspace<DType> filled_response_value(device, ctx, Workspace<DType> filled_response_value(device, ctx,
response_prefix_host.back()*num_feat); response_prefix_host.back()*num_feat);
if (request_prefix_host.back() > 0) { if (response_prefix_host.back() > 0) {
dim3 block(256, 1); dim3 block(256, 1);
while (block.x >= 2*num_feat) { while (block.x >= 2*num_feat) {
block.x /= 2; block.x /= 2;
block.y *= 2; block.y *= 2;
} }
const dim3 grid((request_prefix_host.back()+block.y-1)/block.y); const dim3 grid((response_prefix_host.back()+block.y-1)/block.y);
aten::impl::IndexSelectMultiKernel<<<grid, block, 0, stream>>>( aten::impl::IndexSelectMultiKernel<<<grid, block, 0, stream>>>(
static_cast<const DType*>(local_tensor->data), static_cast<const DType*>(local_tensor->data),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment