Unverified Commit dc06060b authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[CUDA] fixing partition bug (#6001)

parent 02e79a3d
...@@ -329,14 +329,14 @@ std::pair<IdArray, NDArray> GeneratePermutationFromRemainder( ...@@ -329,14 +329,14 @@ std::pair<IdArray, NDArray> GeneratePermutationFromRemainder(
CUDA_CALL(cub::DeviceHistogram::HistogramEven( CUDA_CALL(cub::DeviceHistogram::HistogramEven(
nullptr, hist_workspace_size, proc_id_out.get(), nullptr, hist_workspace_size, proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1, reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
static_cast<IdType>(0), static_cast<IdType>(num_parts + 1), static_cast<IdType>(0), static_cast<IdType>(num_parts),
static_cast<int>(num_in), stream)); static_cast<int>(num_in), stream));
Workspace<void> hist_workspace(device, ctx, hist_workspace_size); Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
CUDA_CALL(cub::DeviceHistogram::HistogramEven( CUDA_CALL(cub::DeviceHistogram::HistogramEven(
hist_workspace.get(), hist_workspace_size, proc_id_out.get(), hist_workspace.get(), hist_workspace_size, proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1, reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
static_cast<IdType>(0), static_cast<IdType>(num_parts + 1), static_cast<IdType>(0), static_cast<IdType>(num_parts),
static_cast<int>(num_in), stream)); static_cast<int>(num_in), stream));
} }
...@@ -504,14 +504,14 @@ std::pair<IdArray, NDArray> GeneratePermutationFromRange( ...@@ -504,14 +504,14 @@ std::pair<IdArray, NDArray> GeneratePermutationFromRange(
CUDA_CALL(cub::DeviceHistogram::HistogramEven( CUDA_CALL(cub::DeviceHistogram::HistogramEven(
nullptr, hist_workspace_size, proc_id_out.get(), nullptr, hist_workspace_size, proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1, reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
static_cast<IdType>(0), static_cast<IdType>(num_parts + 1), static_cast<IdType>(0), static_cast<IdType>(num_parts),
static_cast<int>(num_in), stream)); static_cast<int>(num_in), stream));
Workspace<void> hist_workspace(device, ctx, hist_workspace_size); Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
CUDA_CALL(cub::DeviceHistogram::HistogramEven( CUDA_CALL(cub::DeviceHistogram::HistogramEven(
hist_workspace.get(), hist_workspace_size, proc_id_out.get(), hist_workspace.get(), hist_workspace_size, proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1, reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
static_cast<IdType>(0), static_cast<IdType>(num_parts + 1), static_cast<IdType>(0), static_cast<IdType>(num_parts),
static_cast<int>(num_in), stream)); static_cast<int>(num_in), stream));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment