Unverified Commit dc06060b authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[CUDA] fixing partition bug (#6001)

parent 02e79a3d
......@@ -329,14 +329,14 @@ std::pair<IdArray, NDArray> GeneratePermutationFromRemainder(
CUDA_CALL(cub::DeviceHistogram::HistogramEven(
nullptr, hist_workspace_size, proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
static_cast<IdType>(0), static_cast<IdType>(num_parts),
static_cast<int>(num_in), stream));
Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
CUDA_CALL(cub::DeviceHistogram::HistogramEven(
hist_workspace.get(), hist_workspace_size, proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
static_cast<IdType>(0), static_cast<IdType>(num_parts),
static_cast<int>(num_in), stream));
}
......@@ -504,14 +504,14 @@ std::pair<IdArray, NDArray> GeneratePermutationFromRange(
CUDA_CALL(cub::DeviceHistogram::HistogramEven(
nullptr, hist_workspace_size, proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
static_cast<IdType>(0), static_cast<IdType>(num_parts),
static_cast<int>(num_in), stream));
Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
CUDA_CALL(cub::DeviceHistogram::HistogramEven(
hist_workspace.get(), hist_workspace_size, proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
static_cast<IdType>(0), static_cast<IdType>(num_parts),
static_cast<int>(num_in), stream));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment