Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
GLM-130B_fastertransformer
Commits
f8a481f8
Commit
f8a481f8
authored
Oct 13, 2023
by
zhouxiang
Browse files
添加dtk中的cub头文件
parent
7b7c64c5
Changes
147
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
8656 additions
and
0 deletions
+8656
-0
3rdparty/cub/rocprim/device/detail/device_merge_sort_mergepath.hpp
...cub/rocprim/device/detail/device_merge_sort_mergepath.hpp
+439
-0
3rdparty/cub/rocprim/device/detail/device_partition.hpp
3rdparty/cub/rocprim/device/detail/device_partition.hpp
+897
-0
3rdparty/cub/rocprim/device/detail/device_radix_sort.hpp
3rdparty/cub/rocprim/device/detail/device_radix_sort.hpp
+1070
-0
3rdparty/cub/rocprim/device/detail/device_reduce.hpp
3rdparty/cub/rocprim/device/detail/device_reduce.hpp
+184
-0
3rdparty/cub/rocprim/device/detail/device_reduce_by_key.hpp
3rdparty/cub/rocprim/device/detail/device_reduce_by_key.hpp
+644
-0
3rdparty/cub/rocprim/device/detail/device_scan_by_key.hpp
3rdparty/cub/rocprim/device/detail/device_scan_by_key.hpp
+388
-0
3rdparty/cub/rocprim/device/detail/device_scan_common.hpp
3rdparty/cub/rocprim/device/detail/device_scan_common.hpp
+153
-0
3rdparty/cub/rocprim/device/detail/device_scan_lookback.hpp
3rdparty/cub/rocprim/device/detail/device_scan_lookback.hpp
+222
-0
3rdparty/cub/rocprim/device/detail/device_scan_reduce_then_scan.hpp
...ub/rocprim/device/detail/device_scan_reduce_then_scan.hpp
+469
-0
3rdparty/cub/rocprim/device/detail/device_segmented_radix_sort.hpp
...cub/rocprim/device/detail/device_segmented_radix_sort.hpp
+990
-0
3rdparty/cub/rocprim/device/detail/device_segmented_reduce.hpp
...rty/cub/rocprim/device/detail/device_segmented_reduce.hpp
+166
-0
3rdparty/cub/rocprim/device/detail/device_segmented_scan.hpp
3rdparty/cub/rocprim/device/detail/device_segmented_scan.hpp
+236
-0
3rdparty/cub/rocprim/device/detail/device_transform.hpp
3rdparty/cub/rocprim/device/detail/device_transform.hpp
+154
-0
3rdparty/cub/rocprim/device/detail/lookback_scan_state.hpp
3rdparty/cub/rocprim/device/detail/lookback_scan_state.hpp
+459
-0
3rdparty/cub/rocprim/device/detail/ordered_block_id.hpp
3rdparty/cub/rocprim/device/detail/ordered_block_id.hpp
+87
-0
3rdparty/cub/rocprim/device/detail/uint_fast_div.hpp
3rdparty/cub/rocprim/device/detail/uint_fast_div.hpp
+106
-0
3rdparty/cub/rocprim/device/device_adjacent_difference.hpp
3rdparty/cub/rocprim/device/device_adjacent_difference.hpp
+523
-0
3rdparty/cub/rocprim/device/device_adjacent_difference_config.hpp
.../cub/rocprim/device/device_adjacent_difference_config.hpp
+84
-0
3rdparty/cub/rocprim/device/device_binary_search.hpp
3rdparty/cub/rocprim/device/device_binary_search.hpp
+177
-0
3rdparty/cub/rocprim/device/device_histogram.hpp
3rdparty/cub/rocprim/device/device_histogram.hpp
+1208
-0
No files found.
Too many changes to show.
To preserve performance only
147 of 147+
files are displayed.
Plain diff
Email patch
3rdparty/cub/rocprim/device/detail/device_merge_sort_mergepath.hpp
0 → 100644
View file @
f8a481f8
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2022, Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_MERGEPATH_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_MERGEPATH_HPP_
#include <iterator>
#include "../../detail/various.hpp"
#include "device_merge_sort.hpp"
#include "device_merge.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace
detail
{
// Load items from input1 and input2 from global memory
// Load up to ItemsPerThread items per thread into registers from the
// concatenation of two global-memory ranges (input1 followed by input2).
// Addressing is block-strided: item i of a thread reads index
// flat_block_size() * i + threadIdx.x of the concatenated range.
// When IsLastTile is true, reads past count1 + count2 are skipped
// (the corresponding register slots are left unwritten).
template<
    unsigned int ItemsPerThread,
    class KeyT,
    class InputIterator>
ROCPRIM_DEVICE ROCPRIM_INLINE
void gmem_to_reg(KeyT (&output)[ItemsPerThread],
                 InputIterator input1,
                 InputIterator input2,
                 unsigned int count1,
                 unsigned int count2,
                 bool IsLastTile)
{
    if(IsLastTile)
    {
        // Guarded path: only indices inside the valid portion are read.
        ROCPRIM_UNROLL
        for(unsigned int item = 0; item < ItemsPerThread; ++item)
        {
            unsigned int idx = rocprim::flat_block_size() * item + threadIdx.x;
            if(idx < count1 + count2)
            {
                // First count1 elements come from input1, the rest from input2.
                output[item] = (idx < count1) ? input1[idx] : input2[idx - count1];
            }
        }
    }
    else
    {
        // Unguarded path: the whole tile is valid, no bounds check needed.
        ROCPRIM_UNROLL
        for(unsigned int item = 0; item < ItemsPerThread; ++item)
        {
            unsigned int idx = rocprim::flat_block_size() * item + threadIdx.x;
            output[item] = (idx < count1) ? input1[idx] : input2[idx - count1];
        }
    }
}
// Store each thread's register array into `output` (typically shared
// memory) using the same block-strided layout as gmem_to_reg:
// element i of a thread lands at index BlockSize * i + threadIdx.x.
// No bounds checking is performed; the destination must hold
// BlockSize * ItemsPerThread elements.
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    class KeyT,
    class OutputIterator>
ROCPRIM_DEVICE ROCPRIM_INLINE
void reg_to_shared(OutputIterator output,
                   KeyT (&input)[ItemsPerThread])
{
    ROCPRIM_UNROLL
    for(unsigned int item = 0; item < ItemsPerThread; ++item)
    {
        unsigned int idx = BlockSize * item + threadIdx.x;
        output[idx] = input[item];
    }
}
// Merge one tile of two sorted runs using the merge-path partitioning in
// `merge_partitions`. This overload is selected when the value type is NOT
// trivially copyable, or is floating point / integral; values are permuted
// through shared memory in registers alongside the keys and written out via
// block_store. The sibling overload handles the complementary value types.
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    class KeysInputIterator,
    class KeysOutputIterator,
    class ValuesInputIterator,
    class ValuesOutputIterator,
    class OffsetT,
    class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
auto block_merge_process_tile(KeysInputIterator keys_input,
                              KeysOutputIterator keys_output,
                              ValuesInputIterator values_input,
                              ValuesOutputIterator values_output,
                              const OffsetT input_size,
                              const OffsetT sorted_block_size,
                              BinaryFunction compare_function,
                              const OffsetT* merge_partitions)
    -> std::enable_if_t<(!std::is_trivially_copyable<
                             typename std::iterator_traits<ValuesInputIterator>::value_type>::value
                         || rocprim::is_floating_point<
                             typename std::iterator_traits<ValuesInputIterator>::value_type>::value
                         || std::is_integral<
                             typename std::iterator_traits<ValuesInputIterator>::value_type>::value),
                        void>
{
    using key_type   = typename std::iterator_traits<KeysInputIterator>::value_type;
    using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;
    // empty_type marks keys-only sorting; value handling is compiled out.
    constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
    constexpr unsigned int items_per_tile = BlockSize * ItemsPerThread;
    using block_store = block_store_impl<with_values, BlockSize, ItemsPerThread, key_type, value_type>;
    // +1 slot: presumably padding used by serial_merge's lookahead — confirm in device_merge.hpp.
    using keys_storage_   = key_type[items_per_tile + 1];
    using values_storage_ = value_type[items_per_tile + 1];
    // Shared memory is reused for three mutually exclusive purposes, hence the union.
    ROCPRIM_SHARED_MEMORY union
    {
        typename block_store::storage_type store;
        detail::raw_storage<keys_storage_> keys;
        detail::raw_storage<values_storage_> values;
    } storage;
    auto& keys_shared   = storage.keys.get();
    auto& values_shared = storage.values.get();
    const unsigned short flat_id       = block_thread_id<0>();
    const unsigned int   flat_block_id = block_id<0>();
    // Only the block covering the tail of the input may be partially filled.
    const bool IsIncompleteTile = flat_block_id == (input_size / items_per_tile);
    // Merge-path split points computed by a prior partitioning kernel.
    const OffsetT partition_beg = merge_partitions[flat_block_id];
    const OffsetT partition_end = merge_partitions[flat_block_id + 1];
    const unsigned int merged_tiles_number        = sorted_block_size / items_per_tile;
    const unsigned int target_merged_tiles_number = merged_tiles_number * 2;
    // merged_tiles_number is a power of two here, so this mask selects the
    // tile's position within its tile-group — TODO confirm against callers.
    const unsigned int mask               = target_merged_tiles_number - 1;
    const unsigned int tilegroup_start_id = ~mask & flat_block_id;
    const OffsetT tilegroup_start = items_per_tile * tilegroup_start_id;
    // Tile-group starts here
    const OffsetT diag = items_per_tile * flat_block_id - tilegroup_start;
    // Ranges of the two sorted input runs assigned to this tile.
    const OffsetT keys1_beg = partition_beg;
    OffsetT       keys1_end = partition_end;
    const OffsetT keys2_beg = rocprim::min(input_size,
                                           2 * tilegroup_start + sorted_block_size + diag
                                               - partition_beg);
    OffsetT keys2_end = rocprim::min(input_size,
                                     2 * tilegroup_start + sorted_block_size + diag
                                         + items_per_tile - partition_end);
    if(mask == (mask & flat_block_id))
    // If last tile in the tile-group
    {
        keys1_end = rocprim::min(input_size, tilegroup_start + sorted_block_size);
        keys2_end = rocprim::min(input_size, tilegroup_start + sorted_block_size * 2);
    }
    // Number of keys per tile
    const unsigned int num_keys1 = static_cast<unsigned int>(keys1_end - keys1_beg);
    const unsigned int num_keys2 = static_cast<unsigned int>(keys2_end - keys2_beg);
    // Load keys1 & keys2
    key_type keys[ItemsPerThread];
    gmem_to_reg<ItemsPerThread>(keys,
                                keys_input + keys1_beg,
                                keys_input + keys2_beg,
                                num_keys1,
                                num_keys2,
                                IsIncompleteTile);
    // Load keys into shared memory
    reg_to_shared<BlockSize, ItemsPerThread>(keys_shared, keys);
    value_type values[ItemsPerThread];
    if ROCPRIM_IF_CONSTEXPR(with_values)
    {
        // Values are staged in registers first; they move through shared
        // memory only after the keys phase is done (union reuse).
        gmem_to_reg<ItemsPerThread>(values,
                                    values_input + keys1_beg,
                                    values_input + keys2_beg,
                                    num_keys1,
                                    num_keys2,
                                    IsIncompleteTile);
    }
    rocprim::syncthreads();
    // Per-thread merge-path split inside the tile's shared-memory window.
    const unsigned int diag0_local = rocprim::min(num_keys1 + num_keys2, ItemsPerThread * flat_id);
    const unsigned int keys1_beg_local = merge_path(keys_shared,
                                                    &keys_shared[num_keys1],
                                                    num_keys1,
                                                    num_keys2,
                                                    diag0_local,
                                                    compare_function);
    const unsigned int keys1_end_local = num_keys1;
    const unsigned int keys2_beg_local = diag0_local - keys1_beg_local;
    const unsigned int keys2_end_local = num_keys2;
    range_t range_local = {keys1_beg_local,
                           keys1_end_local,
                           keys2_beg_local + keys1_end_local,
                           keys2_end_local + keys1_end_local};
    // indices[] records where each merged key came from so values can be
    // permuted identically afterwards.
    unsigned int indices[ItemsPerThread];
    serial_merge(keys_shared, keys, indices, range_local, compare_function);
    if ROCPRIM_IF_CONSTEXPR(with_values)
    {
        // Reuse shared memory (now free of keys) to permute the values.
        reg_to_shared<BlockSize, ItemsPerThread>(values_shared, values);
        rocprim::syncthreads();
        ROCPRIM_UNROLL
        for(unsigned int item = 0; item < ItemsPerThread; ++item)
        {
            values[item] = values_shared[indices[item]];
        }
        // Guard before block_store reuses the same shared memory.
        rocprim::syncthreads();
    }
    const OffsetT offset = flat_block_id * items_per_tile;
    block_store().store(offset,
                        input_size - offset,
                        IsIncompleteTile,
                        keys_output,
                        values_output,
                        keys,
                        values,
                        storage.store);
}
// Merge one tile of two sorted runs (merge-path). This overload is selected
// for value types that are trivially copyable but neither floating point nor
// integral (the complement of the sibling overload). Values are NOT carried
// through block_store: they are scattered directly from global to shared
// memory and written to values_output inline; block_store handles keys only
// (note block_store_impl<false, ...>).
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    class KeysInputIterator,
    class KeysOutputIterator,
    class ValuesInputIterator,
    class ValuesOutputIterator,
    class OffsetT,
    class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
auto block_merge_process_tile(KeysInputIterator keys_input,
                              KeysOutputIterator keys_output,
                              ValuesInputIterator values_input,
                              ValuesOutputIterator values_output,
                              const OffsetT input_size,
                              const OffsetT sorted_block_size,
                              BinaryFunction compare_function,
                              const OffsetT* merge_partitions)
    -> std::enable_if_t<(std::is_trivially_copyable<
                             typename std::iterator_traits<ValuesInputIterator>::value_type>::value
                         && !rocprim::is_floating_point<
                             typename std::iterator_traits<ValuesInputIterator>::value_type>::value
                         && !std::is_integral<
                             typename std::iterator_traits<ValuesInputIterator>::value_type>::value),
                        void>
{
    using key_type   = typename std::iterator_traits<KeysInputIterator>::value_type;
    using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;
    // empty_type marks keys-only sorting; value handling is compiled out.
    constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
    constexpr unsigned int items_per_tile = BlockSize * ItemsPerThread;
    // Keys-only store: values are written directly below, not via block_store.
    using block_store = block_store_impl<false, BlockSize, ItemsPerThread, key_type, value_type>;
    // +1 slot: presumably padding used by serial_merge's lookahead — confirm in device_merge.hpp.
    using keys_storage_   = key_type[items_per_tile + 1];
    using values_storage_ = value_type[items_per_tile + 1];
    // Shared memory reused for three mutually exclusive phases, hence the union.
    ROCPRIM_SHARED_MEMORY union
    {
        typename block_store::storage_type store;
        detail::raw_storage<keys_storage_> keys;
        detail::raw_storage<values_storage_> values;
    } storage;
    auto& keys_shared   = storage.keys.get();
    auto& values_shared = storage.values.get();
    const unsigned short flat_id       = block_thread_id<0>();
    const unsigned int   flat_block_id = block_id<0>();
    // Only the block covering the tail of the input may be partially filled.
    const bool IsIncompleteTile = flat_block_id == (input_size / items_per_tile);
    // Merge-path split points computed by a prior partitioning kernel.
    const OffsetT partition_beg = merge_partitions[flat_block_id];
    const OffsetT partition_end = merge_partitions[flat_block_id + 1];
    const unsigned int merged_tiles_number        = sorted_block_size / items_per_tile;
    const unsigned int target_merged_tiles_number = merged_tiles_number * 2;
    const unsigned int mask               = target_merged_tiles_number - 1;
    const unsigned int tilegroup_start_id = ~mask & flat_block_id;
    const OffsetT tilegroup_start = items_per_tile * tilegroup_start_id;
    // Tile-group starts here
    const OffsetT diag = items_per_tile * flat_block_id - tilegroup_start;
    // Ranges of the two sorted input runs assigned to this tile.
    const OffsetT keys1_beg = partition_beg;
    OffsetT       keys1_end = partition_end;
    const OffsetT keys2_beg = rocprim::min(input_size,
                                           2 * tilegroup_start + sorted_block_size + diag
                                               - partition_beg);
    OffsetT keys2_end = rocprim::min(input_size,
                                     2 * tilegroup_start + sorted_block_size + diag
                                         + items_per_tile - partition_end);
    if(mask == (mask & flat_block_id))
    // If last tile in the tile-group
    {
        keys1_end = rocprim::min(input_size, tilegroup_start + sorted_block_size);
        keys2_end = rocprim::min(input_size, tilegroup_start + sorted_block_size * 2);
    }
    // Number of keys per tile
    const unsigned int num_keys1 = static_cast<unsigned int>(keys1_end - keys1_beg);
    const unsigned int num_keys2 = static_cast<unsigned int>(keys2_end - keys2_beg);
    // Load keys1 & keys2
    key_type keys[ItemsPerThread];
    gmem_to_reg<ItemsPerThread>(keys,
                                keys_input + keys1_beg,
                                keys_input + keys2_beg,
                                num_keys1,
                                num_keys2,
                                IsIncompleteTile);
    // Load keys into shared memory
    reg_to_shared<BlockSize, ItemsPerThread>(keys_shared, keys);
    rocprim::syncthreads();
    // Per-thread merge-path split inside the tile's shared-memory window.
    const unsigned int diag0_local = rocprim::min(num_keys1 + num_keys2, ItemsPerThread * flat_id);
    const unsigned int keys1_beg_local = merge_path(keys_shared,
                                                    &keys_shared[num_keys1],
                                                    num_keys1,
                                                    num_keys2,
                                                    diag0_local,
                                                    compare_function);
    const unsigned int keys1_end_local = num_keys1;
    const unsigned int keys2_beg_local = diag0_local - keys1_beg_local;
    const unsigned int keys2_end_local = num_keys2;
    range_t range_local = {keys1_beg_local,
                           keys1_end_local,
                           keys2_beg_local + keys1_end_local,
                           keys2_end_local + keys1_end_local};
    // indices[] records where each merged key came from so values can be
    // gathered identically afterwards.
    unsigned int indices[ItemsPerThread];
    serial_merge(keys_shared, keys, indices, range_local, compare_function);
    if ROCPRIM_IF_CONSTEXPR(with_values)
    {
        const ValuesInputIterator input1 = values_input + keys1_beg;
        const ValuesInputIterator input2 = values_input + keys2_beg;
        // Stage the two value runs contiguously in shared memory
        // (run 1 at [0, num_keys1), run 2 immediately after).
        if(IsIncompleteTile)
        {
            ROCPRIM_UNROLL
            for(unsigned int item = 0; item < ItemsPerThread; ++item)
            {
                unsigned int idx = BlockSize * item + threadIdx.x;
                if(idx < num_keys1)
                {
                    values_shared[idx] = input1[idx];
                }
                else if(idx - num_keys1 < num_keys2)
                {
                    values_shared[idx] = input2[idx - num_keys1];
                }
            }
        }
        else
        {
            ROCPRIM_UNROLL
            for(unsigned int item = 0; item < ItemsPerThread; ++item)
            {
                unsigned int idx = BlockSize * item + threadIdx.x;
                if(idx < num_keys1)
                {
                    values_shared[idx] = input1[idx];
                }
                else
                {
                    values_shared[idx] = input2[idx - num_keys1];
                }
            }
        }
        rocprim::syncthreads();
        // Gather values through indices[] and write them straight to global
        // memory, each thread owning a contiguous ItemsPerThread chunk.
        const OffsetT offset = (flat_block_id * items_per_tile)
                               + (threadIdx.x * ItemsPerThread);
        ROCPRIM_UNROLL
        for(unsigned int item = 0; item < ItemsPerThread; ++item)
        {
            values_output[offset + item] = values_shared[indices[item]];
        }
        // Guard before block_store reuses the same shared memory.
        rocprim::syncthreads();
    }
    const OffsetT offset = flat_block_id * items_per_tile;
    // values[] is never read by this keys-only block_store; it only satisfies
    // the store() signature.
    value_type values[ItemsPerThread];
    block_store().store(offset,
                        input_size - offset,
                        IsIncompleteTile,
                        keys_output,
                        values_output,
                        keys,
                        values,
                        storage.store);
}
// Kernel-body entry point for one merge step of the mergepath merge sort.
// Simply forwards to the block_merge_process_tile overload selected by the
// value type's traits (see the two overloads above).
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    class KeysInputIterator,
    class KeysOutputIterator,
    class ValuesInputIterator,
    class ValuesOutputIterator,
    class OffsetT,
    class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void block_merge_kernel_impl(KeysInputIterator keys_input,
                             KeysOutputIterator keys_output,
                             ValuesInputIterator values_input,
                             ValuesOutputIterator values_output,
                             const OffsetT input_size,
                             const OffsetT sorted_block_size,
                             BinaryFunction compare_function,
                             const OffsetT* merge_partitions)
{
    block_merge_process_tile<BlockSize, ItemsPerThread>(keys_input,
                                                        keys_output,
                                                        values_input,
                                                        values_output,
                                                        input_size,
                                                        sorted_block_size,
                                                        compare_function,
                                                        merge_partitions);
}
}
// end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_MERGEPATH_HPP_
\ No newline at end of file
3rdparty/cub/rocprim/device/detail/device_partition.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_PARTITION_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_PARTITION_HPP_
#include <type_traits>
#include <iterator>
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../types.hpp"
#include "../../block/block_load.hpp"
#include "../../block/block_store.hpp"
#include "../../block/block_scan.hpp"
#include "../../block/block_discontinuity.hpp"
#include "device_scan_lookback.hpp"
#include "lookback_scan_state.hpp"
#include "ordered_block_id.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace
detail
{
// Decoupled-lookback scan prefix functor specialized for rocprim::plus<T>.
// Extends lookback_scan_prefix_op by additionally publishing (from lane 0)
// the block's own reduction and its exclusive prefix into caller-provided
// shared storage, so other threads can read them back after a sync.
template<
    class T,
    class LookbackScanState>
class offset_lookback_scan_prefix_op
    : public lookback_scan_prefix_op<T, ::rocprim::plus<T>, LookbackScanState>
{
    using base_type = lookback_scan_prefix_op<T, ::rocprim::plus<T>, LookbackScanState>;
    using binary_op_type = ::rocprim::plus<T>;

public:
    // Shared-memory layout filled in by operator() (lane 0 only).
    struct storage_type
    {
        T block_reduction;
        T exclusive_prefix;
    };

    ROCPRIM_DEVICE ROCPRIM_INLINE
    offset_lookback_scan_prefix_op(unsigned int block_id,
                                   LookbackScanState& state,
                                   storage_type& storage)
        : base_type(block_id, binary_op_type(), state), storage_(storage)
    {
    }

    ROCPRIM_DEVICE ROCPRIM_INLINE
    ~offset_lookback_scan_prefix_op() = default;

    // Performs the lookback (via the base class) and caches the inputs and
    // result in shared storage. Only lane 0 writes; callers must sync before
    // other lanes read via the getters below.
    ROCPRIM_DEVICE ROCPRIM_INLINE
    T operator()(T reduction)
    {
        auto prefix = base_type::operator()(reduction);
        if(::rocprim::lane_id() == 0)
        {
            storage_.block_reduction  = reduction;
            storage_.exclusive_prefix = prefix;
        }
        return prefix;
    }

    // Reduction of this block, as last passed to operator().
    ROCPRIM_DEVICE ROCPRIM_INLINE
    T get_reduction() const
    {
        return storage_.block_reduction;
    }

    // Exclusive prefix of this block, as last computed by operator().
    ROCPRIM_DEVICE ROCPRIM_INLINE
    T get_exclusive_prefix() const
    {
        return storage_.exclusive_prefix;
    }

private:
    storage_type& storage_;
};
enum
class
select_method
{
flag
=
0
,
predicate
=
1
,
unique
=
2
};
// select_method::flag overload: load per-item selection flags for this
// block from `block_flags` using a block load. The last (partial) block
// uses the guarded load with `false` as the out-of-bounds default.
// Unused parameters exist only to keep a uniform call site across the
// select_method overloads.
template<
    select_method SelectMethod,
    unsigned int BlockSize,
    class BlockLoadFlagsType,
    class BlockDiscontinuityType,
    class InputIterator,
    class FlagIterator,
    class ValueType,
    unsigned int ItemsPerThread,
    class UnaryPredicate,
    class InequalityOp,
    class StorageType>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto partition_block_load_flags(InputIterator /* block_predecessor */,
                                FlagIterator block_flags,
                                ValueType (&/* values */)[ItemsPerThread],
                                bool (&is_selected)[ItemsPerThread],
                                UnaryPredicate /* predicate */,
                                InequalityOp /* inequality_op */,
                                StorageType& storage,
                                const unsigned int /* block_id */,
                                const unsigned int /* block_thread_id */,
                                const bool is_last_block,
                                const unsigned int valid_in_last_block)
    -> typename std::enable_if<SelectMethod == select_method::flag>::type
{
    if(is_last_block)
    // last block
    {
        // Guarded load: flags beyond valid_in_last_block default to false.
        BlockLoadFlagsType().load(block_flags,
                                  is_selected,
                                  valid_in_last_block,
                                  false,
                                  storage.load_flags);
    }
    else
    {
        BlockLoadFlagsType().load(block_flags, is_selected, storage.load_flags);
    }
    ::rocprim::syncthreads();
    // sync threads to reuse shared memory
}
// select_method::predicate overload: compute each item's selection flag by
// applying `predicate` to the already-loaded values (blocked arrangement:
// thread t owns items [t*ItemsPerThread, (t+1)*ItemsPerThread)). In the last
// (partial) block, items past valid_in_last_block are forced to false.
// No shared memory is touched, so no syncthreads is needed here.
template<
    select_method SelectMethod,
    unsigned int BlockSize,
    class BlockLoadFlagsType,
    class BlockDiscontinuityType,
    class InputIterator,
    class FlagIterator,
    class ValueType,
    unsigned int ItemsPerThread,
    class UnaryPredicate,
    class InequalityOp,
    class StorageType>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto partition_block_load_flags(InputIterator /* block_predecessor */,
                                FlagIterator /* block_flags */,
                                ValueType (&values)[ItemsPerThread],
                                bool (&is_selected)[ItemsPerThread],
                                UnaryPredicate predicate,
                                InequalityOp /* inequality_op */,
                                StorageType& /* storage */,
                                const unsigned int /* block_id */,
                                const unsigned int block_thread_id,
                                const bool is_last_block,
                                const unsigned int valid_in_last_block)
    -> typename std::enable_if<SelectMethod == select_method::predicate>::type
{
    if(is_last_block)
    // last block
    {
        const auto offset = block_thread_id * ItemsPerThread;
        ROCPRIM_UNROLL
        for(unsigned int i = 0; i < ItemsPerThread; i++)
        {
            if((offset + i) < valid_in_last_block)
            {
                is_selected[i] = predicate(values[i]);
            }
            else
            {
                // Out-of-range items must never be selected.
                is_selected[i] = false;
            }
        }
    }
    else
    {
        ROCPRIM_UNROLL
        for(unsigned int i = 0; i < ItemsPerThread; i++)
        {
            is_selected[i] = predicate(values[i]);
        }
    }
}
// This wrapper processes only part of items and flags (valid_count - 1)th item (for tails)
// and (valid_count)th item (for heads), all items after valid_count are unflagged.
// Wraps an inequality operator so that comparisons involving items at or
// beyond `valid_count` always report "equal" (i.e. not flagged). Used by
// the unique overload below to keep out-of-range items unselected in the
// last, partially filled block.
template<
    class InequalityOp>
struct guarded_inequality_op
{
    InequalityOp inequality_op; // wrapped comparison
    unsigned int valid_count;   // first invalid index; b_index >= this is never flagged

    ROCPRIM_DEVICE ROCPRIM_INLINE
    guarded_inequality_op(InequalityOp inequality_op, unsigned int valid_count)
        : inequality_op(inequality_op), valid_count(valid_count)
    {}

    // Returns true only when b is inside the valid range AND the wrapped
    // operator flags the pair as unequal.
    template<
        class T,
        class U>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    bool operator()(const T& a, const U& b, unsigned int b_index)
    {
        return (b_index < valid_count && inequality_op(a, b));
    }
};
// select_method::unique overload: flag the head of each run of equal values
// using block_discontinuity. For every block except the first, the last
// element of the previous block (*block_predecessor) seeds the comparison so
// runs spanning block boundaries are not double-counted. In the last block
// the comparison is guarded (guarded_inequality_op) and trailing invalid
// items are explicitly cleared.
template<
    select_method SelectMethod,
    unsigned int BlockSize,
    class BlockLoadFlagsType,
    class BlockDiscontinuityType,
    class InputIterator,
    class FlagIterator,
    class ValueType,
    unsigned int ItemsPerThread,
    class UnaryPredicate,
    class InequalityOp,
    class StorageType>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto partition_block_load_flags(InputIterator block_predecessor,
                                FlagIterator /* block_flags */,
                                ValueType (&values)[ItemsPerThread],
                                bool (&is_selected)[ItemsPerThread],
                                UnaryPredicate /* predicate */,
                                InequalityOp inequality_op,
                                StorageType& storage,
                                const unsigned int block_id,
                                const unsigned int block_thread_id,
                                const bool is_last_block,
                                const unsigned int valid_in_last_block)
    -> typename std::enable_if<SelectMethod == select_method::unique>::type
{
    if(block_id > 0)
    {
        // Seed with the last item of the previous block to detect a run
        // continuing across the block boundary.
        const ValueType predecessor = *block_predecessor;
        if(is_last_block)
        {
            BlockDiscontinuityType().flag_heads(is_selected,
                                                predecessor,
                                                values,
                                                guarded_inequality_op<InequalityOp>(
                                                    inequality_op,
                                                    valid_in_last_block),
                                                storage.discontinuity_values);
        }
        else
        {
            BlockDiscontinuityType().flag_heads(is_selected,
                                                predecessor,
                                                values,
                                                inequality_op,
                                                storage.discontinuity_values);
        }
    }
    else
    {
        // First block: no predecessor; the very first item is always a head.
        if(is_last_block)
        {
            BlockDiscontinuityType().flag_heads(is_selected,
                                                values,
                                                guarded_inequality_op<InequalityOp>(
                                                    inequality_op,
                                                    valid_in_last_block),
                                                storage.discontinuity_values);
        }
        else
        {
            BlockDiscontinuityType().flag_heads(is_selected,
                                                values,
                                                inequality_op,
                                                storage.discontinuity_values);
        }
    }
    // Set is_selected for invalid items to false
    if(is_last_block)
    {
        const auto offset = block_thread_id * ItemsPerThread;
        ROCPRIM_UNROLL
        for(unsigned int i = 0; i < ItemsPerThread; i++)
        {
            if((offset + i) >= valid_in_last_block)
            {
                is_selected[i] = false;
            }
        }
    }
    ::rocprim::syncthreads();
    // sync threads to reuse shared memory
}
// Three-way partition variant: fills a two-plane selection array.
// is_selected[0][i] marks items accepted by the first predicate;
// is_selected[1][i] marks items accepted by the second predicate but NOT the
// first (planes are mutually exclusive; items matching neither go to the
// third output). Last-block items past valid_in_last_block are cleared in
// both planes.
template<
    select_method SelectMethod,
    unsigned int BlockSize,
    class BlockLoadFlagsType,
    class BlockDiscontinuityType,
    class InputIterator,
    class FlagIterator,
    class ValueType,
    unsigned int ItemsPerThread,
    class FirstUnaryPredicate,
    class SecondUnaryPredicate,
    class InequalityOp,
    class StorageType>
ROCPRIM_DEVICE ROCPRIM_INLINE
void partition_block_load_flags(InputIterator /*block_predecessor*/,
                                FlagIterator /* block_flags */,
                                ValueType (&values)[ItemsPerThread],
                                bool (&is_selected)[2][ItemsPerThread],
                                FirstUnaryPredicate select_first_part_op,
                                SecondUnaryPredicate select_second_part_op,
                                InequalityOp /*inequality_op*/,
                                StorageType& /*storage*/,
                                const unsigned int /*block_id*/,
                                const unsigned int block_thread_id,
                                const bool is_last_block,
                                const unsigned int valid_in_last_block)
{
    if(is_last_block)
    {
        const auto offset = block_thread_id * ItemsPerThread;
        ROCPRIM_UNROLL
        for(unsigned int i = 0; i < ItemsPerThread; i++)
        {
            if((offset + i) < valid_in_last_block)
            {
                is_selected[0][i] = select_first_part_op(values[i]);
                // Second plane only when the first predicate rejected the item.
                is_selected[1][i] = !is_selected[0][i]
                                    && select_second_part_op(values[i]);
            }
            else
            {
                // Out-of-range items belong to neither partition.
                is_selected[0][i] = false;
                is_selected[1][i] = false;
            }
        }
    }
    else
    {
        ROCPRIM_UNROLL
        for(unsigned int i = 0; i < ItemsPerThread; i++)
        {
            is_selected[0][i] = select_first_part_op(values[i]);
            is_selected[1][i] = !is_selected[0][i]
                                && select_second_part_op(values[i]);
        }
    }
}
// Scatter for two-way partition when BOTH partitions are written
// (!OnlySelected): selected items go to the front of `output` in order,
// rejected items are packed from the back in reverse order. Items are first
// compacted into shared memory (selected at [0, selected_in_block), rejected
// after), then written coalesced to global memory.
template<
    bool OnlySelected,
    unsigned int BlockSize,
    class ValueType,
    unsigned int ItemsPerThread,
    class OffsetType,
    class OutputIterator,
    class ScatterStorageType>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto partition_scatter(ValueType (&values)[ItemsPerThread],
                       bool (&is_selected)[ItemsPerThread],
                       OffsetType (&output_indices)[ItemsPerThread],
                       OutputIterator output,
                       const size_t size,
                       const OffsetType selected_prefix,
                       const OffsetType selected_in_block,
                       ScatterStorageType& storage,
                       const unsigned int flat_block_id,
                       const unsigned int flat_block_thread_id,
                       const bool is_last_block,
                       const unsigned int valid_in_last_block,
                       size_t* /* prev_selected_count */)
    -> typename std::enable_if<!OnlySelected>::type
{
    constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
    // Scatter selected/rejected values to shared memory
    auto scatter_storage = storage.get();
    ROCPRIM_UNROLL
    for(unsigned int i = 0; i < ItemsPerThread; i++)
    {
        unsigned int item_index = (flat_block_thread_id * ItemsPerThread) + i;
        // Position among selected items within this block (output_indices
        // holds the inclusive-scan result over selection flags).
        unsigned int selected_item_index = output_indices[i] - selected_prefix;
        unsigned int rejected_item_index = (item_index - selected_item_index)
                                           + selected_in_block;
        // index of item in scatter_storage
        unsigned int scatter_index = is_selected[i] ? selected_item_index
                                                    : rejected_item_index;
        scatter_storage[scatter_index] = values[i];
    }
    ::rocprim::syncthreads();
    // sync threads to reuse shared memory
    ROCPRIM_UNROLL
    for(unsigned int i = 0; i < ItemsPerThread; i++)
    {
        // Striped re-read for coalesced global writes.
        unsigned int item_index          = (i * BlockSize) + flat_block_thread_id;
        unsigned int selected_item_index = item_index;
        unsigned int rejected_item_index = item_index - selected_in_block;
        // number of values rejected in previous blocks
        unsigned int rejected_prefix = (flat_block_id * items_per_block)
                                       - selected_prefix;
        // destination index of item scatter_storage[item_index] in output
        // (rejected items fill the output back-to-front).
        OffsetType scatter_index = item_index < selected_in_block
                                       ? selected_prefix + selected_item_index
                                       : size - (rejected_prefix + rejected_item_index + 1);
        // last block can store only valid_in_last_block items
        if(!is_last_block || item_index < valid_in_last_block)
        {
            output[scatter_index] = scatter_storage[item_index];
        }
    }
}
// Scatter for select-style partition when ONLY selected items are written
// (OnlySelected). Two strategies: when many items are selected
// (> BlockSize), compact them into shared memory and write coalesced;
// otherwise write each selected item directly. prev_selected_count[0] is a
// global base offset — presumably from earlier launches over preceding
// input segments; confirm against the calling kernel.
template<
    bool OnlySelected,
    unsigned int BlockSize,
    class ValueType,
    unsigned int ItemsPerThread,
    class OffsetType,
    class OutputIterator,
    class ScatterStorageType>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto partition_scatter(ValueType (&values)[ItemsPerThread],
                       bool (&is_selected)[ItemsPerThread],
                       OffsetType (&output_indices)[ItemsPerThread],
                       OutputIterator output,
                       const size_t size,
                       const OffsetType selected_prefix,
                       const OffsetType selected_in_block,
                       ScatterStorageType& storage,
                       const unsigned int flat_block_id,
                       const unsigned int flat_block_thread_id,
                       const bool is_last_block,
                       const unsigned int valid_in_last_block,
                       size_t* prev_selected_count)
    -> typename std::enable_if<OnlySelected>::type
{
    // Silence unused-parameter warnings; which parameters are used depends
    // on the branch taken below.
    (void)size;
    (void)storage;
    (void)flat_block_id;
    (void)flat_block_thread_id;
    (void)valid_in_last_block;
    if(selected_in_block > BlockSize)
    {
        // Scatter selected values to shared memory
        auto scatter_storage = storage.get();
        ROCPRIM_UNROLL
        for(unsigned int i = 0; i < ItemsPerThread; i++)
        {
            // Position among selected items within this block.
            unsigned int scatter_index = output_indices[i] - selected_prefix;
            if(is_selected[i])
            {
                scatter_storage[scatter_index] = values[i];
            }
        }
        ::rocprim::syncthreads();
        // sync threads to reuse shared memory
        // Coalesced write from shared memory to global memory
        for(unsigned int i = flat_block_thread_id; i < selected_in_block; i += BlockSize)
        {
            output[prev_selected_count[0] + selected_prefix + i] = scatter_storage[i];
        }
    }
    else
    {
        // Few selected items: write them directly, skipping shared memory.
        ROCPRIM_UNROLL
        for(unsigned int i = 0; i < ItemsPerThread; i++)
        {
            if(!is_last_block
               || output_indices[i] < (selected_prefix + selected_in_block))
            {
                if(is_selected[i])
                {
                    output[prev_selected_count[0] + output_indices[i]] = values[i];
                }
            }
        }
    }
}
// Scatter variant for three-way partitioning: items selected by the first
// predicate go to get<0>(output), items selected by the second predicate to
// get<1>(output), the rest to get<2>(output). Uses a uint2-like OffsetType
// whose .x/.y components track the two selection channels.
//
// Phase 1 places every item in shared memory so that the layout is
// [first-channel items | second-channel items | unselected items];
// phase 2 reads it back in a striped pattern for coalesced global writes.
template<
    bool OnlySelected,
    unsigned int BlockSize,
    class ValueType,
    unsigned int ItemsPerThread,
    class OffsetType,
    class OutputType,
    class ScatterStorageType
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void partition_scatter(ValueType (&values)[ItemsPerThread],
                       bool (&is_selected)[2][ItemsPerThread],
                       OffsetType (&output_indices)[ItemsPerThread],
                       OutputType output,
                       const size_t /*size*/,
                       const OffsetType selected_prefix,
                       const OffsetType selected_in_block,
                       ScatterStorageType& storage,
                       const unsigned int flat_block_id,
                       const unsigned int flat_block_thread_id,
                       const bool is_last_block,
                       const unsigned int valid_in_last_block,
                       size_t* /* prev_selected_count */)
{
    constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
    auto scatter_storage = storage.get();
    // Phase 1: scatter each item (blocked arrangement) to its block-local slot
    // in shared memory.
    ROCPRIM_UNROLL
    for(unsigned int i = 0; i < ItemsPerThread; i++)
    {
        // Block-local rank within the first-channel region.
        const unsigned int first_selected_item_index
            = output_indices[i].x - selected_prefix.x;
        // Block-local rank within the second-channel region, which starts
        // after all first-channel items of this block.
        const unsigned int second_selected_item_index
            = output_indices[i].y - selected_prefix.y + selected_in_block.x;
        unsigned int scatter_index{};
        if(is_selected[0][i])
        {
            scatter_index = first_selected_item_index;
        }
        else if(is_selected[1][i])
        {
            scatter_index = second_selected_item_index;
        }
        else
        {
            // Unselected rank: item's position minus the two selected ranks,
            // offset past both selected regions. (first/second_selected_item_index
            // here equal the per-channel exclusive-scan counts for this item.)
            const unsigned int item_index = (flat_block_thread_id * ItemsPerThread) + i;
            const unsigned int unselected_item_index
                = (item_index - first_selected_item_index - second_selected_item_index)
                  + 2 * selected_in_block.x + selected_in_block.y;
            scatter_index = unselected_item_index;
        }
        scatter_storage[scatter_index] = values[i];
    }
    ::rocprim::syncthreads();
    // Phase 2: striped read-back (thread t handles positions t, t+BlockSize, ...)
    // so each of the three output streams is written coalesced.
    ROCPRIM_UNROLL
    for(unsigned int i = 0; i < ItemsPerThread; i++)
    {
        const unsigned int item_index = (i * BlockSize) + flat_block_thread_id;
        if(!is_last_block || item_index < valid_in_last_block)
        {
            if(item_index < selected_in_block.x)
            {
                // First-channel region of shared memory -> first output.
                get<0>(output)[item_index + selected_prefix.x]
                    = scatter_storage[item_index];
            }
            else if(item_index < selected_in_block.x + selected_in_block.y)
            {
                // Second-channel region -> second output.
                get<1>(output)[item_index - selected_in_block.x + selected_prefix.y]
                    = scatter_storage[item_index];
            }
            else
            {
                // Remaining items -> rejected output, positioned after the
                // unselected items of all previous blocks.
                const unsigned int all_items_in_previous_blocks
                    = items_per_block * flat_block_id;
                const unsigned int unselected_items_in_previous_blocks
                    = all_items_in_previous_blocks - selected_prefix.x - selected_prefix.y;
                const unsigned int output_index
                    = item_index + unselected_items_in_previous_blocks
                      - selected_in_block.x - selected_in_block.y;
                get<2>(output)[output_index] = scatter_storage[item_index];
            }
        }
    }
}
/// Turns per-item boolean selection flags into 0/1 offsets, preparing
/// \p output_indices as input for an exclusive block scan that yields each
/// selected item's output position.
template<unsigned int items_per_thread, class offset_type>
ROCPRIM_DEVICE ROCPRIM_INLINE
void convert_selected_to_indices(offset_type (&output_indices)[items_per_thread],
                                 bool (&is_selected)[items_per_thread])
{
    ROCPRIM_UNROLL
    for(unsigned int item = 0; item < items_per_thread; ++item)
    {
        // Selected items contribute one slot each to the scan.
        output_indices[item] = is_selected[item] ? 1 : 0;
    }
}
/// Two-channel overload for three-way partitioning: converts the two selection
/// flag arrays into 0/1 counts in the .x and .y components of \p output_indices,
/// ready for a component-wise exclusive block scan.
template<unsigned int items_per_thread>
ROCPRIM_DEVICE ROCPRIM_INLINE
void convert_selected_to_indices(uint2 (&output_indices)[items_per_thread],
                                 bool (&is_selected)[2][items_per_thread])
{
    ROCPRIM_UNROLL
    for(unsigned int item = 0; item < items_per_thread; ++item)
    {
        // .x tracks the first predicate's selections, .y the second's.
        output_indices[item].x = is_selected[0][item] ? 1 : 0;
        output_indices[item].y = is_selected[1][item] ? 1 : 0;
    }
}
/// Writes the running total of selected items: the count carried over from
/// previous launches plus everything selected up to and including the last block.
/// Called by a single thread of the last block.
template<class OffsetT>
ROCPRIM_DEVICE ROCPRIM_INLINE
void store_selected_count(size_t* selected_count,
                          size_t* prev_selected_count,
                          const OffsetT selected_prefix,
                          const OffsetT selected_in_block)
{
    const size_t total = prev_selected_count[0] + selected_prefix + selected_in_block;
    selected_count[0] = total;
}
/// Specialization for three-way partitioning: the .x and .y components carry
/// the totals of the two selection channels, stored into selected_count[0]
/// and selected_count[1] respectively.
template<>
ROCPRIM_DEVICE ROCPRIM_INLINE
void store_selected_count(size_t* selected_count,
                          size_t* prev_selected_count,
                          const uint2 selected_prefix,
                          const uint2 selected_in_block)
{
    const size_t first_total  = prev_selected_count[0] + selected_prefix.x + selected_in_block.x;
    const size_t second_total = prev_selected_count[1] + selected_prefix.y + selected_in_block.y;
    selected_count[0] = first_total;
    selected_count[1] = second_total;
}
// Device-side body of the partition/select kernel. Per block:
//   1. load a tile of keys,
//   2. compute selection flags (from predicates, a flag iterator, or
//      block_discontinuity, depending on SelectMethod),
//   3. exclusive-scan the flags across the grid using the decoupled-lookback
//      scan state to obtain global output indices,
//   4. scatter keys (and values, if any) via partition_scatter,
//   5. last block stores the final selected count(s).
// Blocks are ordered with ordered_block_id so the lookback scan can assume
// block i-1's prefix is available before block i needs it.
template<
    select_method SelectMethod,
    bool OnlySelected,
    class Config,
    class KeyIterator,
    class ValueIterator,
    // Can be rocprim::empty_type* if key only
    class FlagIterator,
    class OutputKeyIterator,
    class OutputValueIterator,
    class InequalityOp,
    class OffsetLookbackScanState,
    class... UnaryPredicates
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void partition_kernel_impl(KeyIterator keys_input,
                           ValueIterator values_input,
                           FlagIterator flags,
                           OutputKeyIterator keys_output,
                           OutputValueIterator values_output,
                           size_t* selected_count,
                           size_t* prev_selected_count,
                           const size_t size,
                           InequalityOp inequality_op,
                           OffsetLookbackScanState offset_scan_state,
                           const unsigned int number_of_blocks,
                           ordered_block_id<unsigned int> ordered_bid,
                           UnaryPredicates... predicates)
{
    constexpr auto block_size = Config::block_size;
    constexpr auto items_per_thread = Config::items_per_thread;
    constexpr unsigned int items_per_block = block_size * items_per_thread;
    using offset_type = typename OffsetLookbackScanState::value_type;
    using key_type = typename std::iterator_traits<KeyIterator>::value_type;
    using value_type = typename std::iterator_traits<ValueIterator>::value_type;
    // Block primitives
    using block_load_key_type = ::rocprim::block_load<
        key_type, block_size, items_per_thread, Config::key_block_load_method>;
    using block_load_value_type = ::rocprim::block_load<
        value_type, block_size, items_per_thread, Config::value_block_load_method>;
    using block_load_flag_type = ::rocprim::block_load<
        bool, block_size, items_per_thread, Config::flag_block_load_method>;
    using block_scan_offset_type = ::rocprim::block_scan<
        offset_type, block_size, Config::block_scan_method>;
    using block_discontinuity_key_type = ::rocprim::block_discontinuity<key_type, block_size>;
    using order_bid_type = ordered_block_id<unsigned int>;
    // Offset prefix operation type
    using offset_scan_prefix_op_type = offset_lookback_scan_prefix_op<
        offset_type, OffsetLookbackScanState>;
    // Memory required for 2-phase scatter
    using exchange_keys_storage_type = key_type[items_per_block];
    using raw_exchange_keys_storage_type = typename detail::raw_storage<exchange_keys_storage_type>;
    using exchange_values_storage_type = value_type[items_per_block];
    using raw_exchange_values_storage_type = typename detail::raw_storage<exchange_values_storage_type>;
    // One flag array per predicate; collapsed to 1-D for the single-predicate case.
    using is_selected_type = std::conditional_t<
        sizeof...(UnaryPredicates) == 1,
        bool[items_per_thread],
        bool[sizeof...(UnaryPredicates)][items_per_thread]>;
    // Shared memory: the union members are used in distinct, syncthreads-separated
    // phases, so they may alias each other.
    ROCPRIM_SHARED_MEMORY struct
    {
        typename order_bid_type::storage_type ordered_bid;
        union
        {
            raw_exchange_keys_storage_type exchange_keys;
            raw_exchange_values_storage_type exchange_values;
            typename block_load_key_type::storage_type load_keys;
            typename block_load_value_type::storage_type load_values;
            typename block_load_flag_type::storage_type load_flags;
            typename block_discontinuity_key_type::storage_type discontinuity_values;
            typename block_scan_offset_type::storage_type scan_offsets;
        };
    } storage;
    const auto flat_block_thread_id = ::rocprim::detail::block_thread_id<0>();
    const auto flat_block_id = ordered_bid.get(flat_block_thread_id, storage.ordered_bid);
    const unsigned int block_offset = flat_block_id * items_per_block;
    const auto valid_in_last_block = size - items_per_block * (number_of_blocks - 1);
    key_type keys[items_per_thread];
    is_selected_type is_selected;
    offset_type output_indices[items_per_thread];
    // Load input values into values
    const bool is_last_block = flat_block_id == (number_of_blocks - 1);
    if(is_last_block) // last block
    {
        block_load_key_type().load(
            keys_input + block_offset,
            keys,
            valid_in_last_block,
            storage.load_keys);
    }
    else
    {
        block_load_key_type().load(
            keys_input + block_offset,
            keys,
            storage.load_keys);
    }
    ::rocprim::syncthreads(); // sync threads to reuse shared memory
    // Load selection flags into is_selected, generate them using
    // input value and selection predicate, or generate them using
    // block_discontinuity primitive
    partition_block_load_flags<
        SelectMethod, block_size,
        block_load_flag_type, block_discontinuity_key_type>(
            keys_input + block_offset - 1, // predecessor key for discontinuity flagging
            flags + block_offset,
            keys,
            is_selected,
            predicates...,
            inequality_op,
            storage,
            flat_block_id,
            flat_block_thread_id,
            is_last_block,
            valid_in_last_block);
    // Convert true/false is_selected flags to 0s and 1s
    convert_selected_to_indices(output_indices, is_selected);
    // Number of selected values in previous blocks
    offset_type selected_prefix{};
    // Number of selected values in this block
    offset_type selected_in_block{};
    // Calculate number of selected values in block and their indices
    if(flat_block_id == 0)
    {
        // First block has no predecessors: plain block-wide scan, then publish
        // its total as a complete prefix for the lookback chain.
        block_scan_offset_type().exclusive_scan(
            output_indices,
            output_indices,
            offset_type{}, /** initial value */
            selected_in_block,
            storage.scan_offsets,
            ::rocprim::plus<offset_type>());
        if(flat_block_thread_id == 0)
        {
            offset_scan_state.set_complete(flat_block_id, selected_in_block);
        }
        ::rocprim::syncthreads(); // sync threads to reuse shared memory
    }
    else
    {
        // Subsequent blocks use the lookback prefix op to obtain the running
        // total of all previous blocks while scanning.
        ROCPRIM_SHARED_MEMORY typename offset_scan_prefix_op_type::storage_type storage_prefix_op;
        auto prefix_op = offset_scan_prefix_op_type(
            flat_block_id,
            offset_scan_state,
            storage_prefix_op);
        block_scan_offset_type().exclusive_scan(
            output_indices,
            output_indices,
            storage.scan_offsets,
            prefix_op,
            ::rocprim::plus<offset_type>());
        ::rocprim::syncthreads(); // sync threads to reuse shared memory
        selected_in_block = prefix_op.get_reduction();
        selected_prefix = prefix_op.get_exclusive_prefix();
    }
    // Scatter selected and rejected values
    partition_scatter<OnlySelected, block_size>(
        keys, is_selected, output_indices, keys_output, size,
        selected_prefix, selected_in_block, storage.exchange_keys,
        flat_block_id, flat_block_thread_id,
        is_last_block, valid_in_last_block, prev_selected_count);
    static constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
    if ROCPRIM_IF_CONSTEXPR(with_values)
    {
        // Values reuse the indices already computed for the keys.
        value_type values[items_per_thread];
        ::rocprim::syncthreads(); // sync threads to reuse shared memory
        if(is_last_block)
        {
            block_load_value_type().load(
                values_input + block_offset,
                values,
                valid_in_last_block,
                storage.load_values);
        }
        else
        {
            block_load_value_type().load(
                values_input + block_offset,
                values,
                storage.load_values);
        }
        ::rocprim::syncthreads(); // sync threads to reuse shared memory
        partition_scatter<OnlySelected, block_size>(
            values, is_selected, output_indices, values_output, size,
            selected_prefix, selected_in_block, storage.exchange_values,
            flat_block_id, flat_block_thread_id,
            is_last_block, valid_in_last_block, prev_selected_count);
    }
    // Last block in grid stores number of selected values
    if(is_last_block && flat_block_thread_id == 0)
    {
        store_selected_count(selected_count, prev_selected_count,
                             selected_prefix, selected_in_block);
    }
}
}
// end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_PARTITION_HPP_
3rdparty/cub/rocprim/device/detail/device_radix_sort.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_RADIX_SORT_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_RADIX_SORT_HPP_
#include <type_traits>
#include <iterator>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../detail/radix_sort.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../types.hpp"
#include "../../block/block_discontinuity.hpp"
#include "../../block/block_exchange.hpp"
#include "../../block/block_load.hpp"
#include "../../block/block_load_func.hpp"
#include "../../block/block_scan.hpp"
#include "../../block/block_radix_sort.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace
detail
{
// Wrapping functions that allow to call proper methods (with or without values)
// (a variant with values is enabled only when Value is not empty_type)
// Dispatch wrapper around block_radix_sort for key-value sorting: picks the
// ascending or descending member function based on the compile-time
// Descending flag so callers need only one call site.
template<
    bool Descending = false,
    class SortType,
    class SortKey,
    class SortValue,
    unsigned int ItemsPerThread
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void sort_block(SortType sorter,
                SortKey (&keys)[ItemsPerThread],
                SortValue (&values)[ItemsPerThread],
                typename SortType::storage_type& storage,
                unsigned int begin_bit,
                unsigned int end_bit)
{
    if(!Descending)
    {
        sorter.sort(keys, values, storage, begin_bit, end_bit);
    }
    else
    {
        sorter.sort_desc(keys, values, storage, begin_bit, end_bit);
    }
}
// Keys-only overload, selected when the value type is rocprim::empty_type:
// forwards to the sorter's keys-only sort / sort_desc and ignores the dummy
// values array.
template<
    bool Descending = false,
    class SortType,
    class SortKey,
    unsigned int ItemsPerThread
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void sort_block(SortType sorter,
                SortKey (&keys)[ItemsPerThread],
                ::rocprim::empty_type (&values)[ItemsPerThread],
                typename SortType::storage_type& storage,
                unsigned int begin_bit,
                unsigned int end_bit)
{
    static_cast<void>(values); // placeholder array, never read
    if(!Descending)
    {
        sorter.sort(keys, storage, begin_bit, end_bit);
    }
    else
    {
        sorter.sort_desc(keys, storage, begin_bit, end_bit);
    }
}
// Block-level helper that counts, for each radix digit value, how many keys in
// a given range have that digit at the current bit position. Uses warp ballot
// voting to aggregate counts with few shared-memory atomics-free updates:
// within a warp, lanes holding the same digit find each other via ballots and
// only the first such lane writes the whole group's count.
template<
    unsigned int WarpSize,
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    unsigned int RadixBits,
    bool Descending
>
struct radix_digit_count_helper
{
    static constexpr unsigned int radix_size = 1 << RadixBits;
    static constexpr unsigned int warp_size = WarpSize;
    static constexpr unsigned int warps_no = BlockSize / warp_size;
    static_assert(BlockSize % ::rocprim::device_warp_size() == 0, "BlockSize must be divisible by warp size");
    static_assert(radix_size <= BlockSize, "Radix size must not exceed BlockSize");

    struct storage_type
    {
        // Per-warp digit histograms, reduced across warps at the end.
        unsigned int digit_counts[warps_no][radix_size];
    };

    // Counts digits of keys_input[begin_offset, end_offset).
    // IsFull == true promises that the range is a whole number of blocks,
    // which lets the bounds checks compile away.
    template<bool IsFull = false, class KeysInputIterator, class Offset>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void count_digits(KeysInputIterator keys_input,
                      Offset begin_offset,
                      Offset end_offset,
                      unsigned int bit,
                      unsigned int current_radix_bits,
                      storage_type& storage,
                      unsigned int& digit_count) // i-th thread will get i-th digit's value
    {
        constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
        using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
        using key_codec = radix_key_codec<key_type, Descending>;
        using bit_key_type = typename key_codec::bit_key_type;
        const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
        const unsigned int warp_id = ::rocprim::warp_id<0, 1, 1>();
        // Zero the per-warp histograms (thread i owns digit i's column).
        if(flat_id < radix_size)
        {
            for(unsigned int w = 0; w < warps_no; w++)
            {
                storage.digit_counts[w][flat_id] = 0;
            }
        }
        ::rocprim::syncthreads();
        for(Offset block_offset = begin_offset; block_offset < end_offset; block_offset += items_per_block)
        {
            key_type keys[ItemsPerThread];
            unsigned int valid_count;
            // Use loading into a striped arrangement because an order of items is irrelevant,
            // only totals matter
            if(IsFull || (block_offset + items_per_block <= end_offset))
            {
                valid_count = items_per_block;
                block_load_direct_striped<BlockSize>(flat_id, keys_input + block_offset, keys);
            }
            else
            {
                valid_count = end_offset - block_offset;
                block_load_direct_striped<BlockSize>(flat_id, keys_input + block_offset, keys, valid_count);
            }
            for(unsigned int i = 0; i < ItemsPerThread; i++)
            {
                const bit_key_type bit_key = key_codec::encode(keys[i]);
                const unsigned int digit = key_codec::extract_digit(bit_key, bit, current_radix_bits);
                const unsigned int pos = i * BlockSize + flat_id;
                // Start with all lanes holding a valid item; then keep only the
                // lanes whose digit matches this lane's digit, bit by bit.
                lane_mask_type same_digit_lanes_mask = ::rocprim::ballot(IsFull || (pos < valid_count));
                for(unsigned int b = 0; b < RadixBits; b++)
                {
                    const unsigned int bit_set = digit & (1u << b);
                    const lane_mask_type bit_set_mask = ::rocprim::ballot(bit_set);
                    same_digit_lanes_mask &= (bit_set ? bit_set_mask : ~bit_set_mask);
                }
                const unsigned int same_digit_count = ::rocprim::bit_count(same_digit_lanes_mask);
                const unsigned int prev_same_digit_count = ::rocprim::masked_bit_count(same_digit_lanes_mask);
                if(prev_same_digit_count == 0)
                {
                    // Write the number of lanes having this digit,
                    // if the current lane is the first (and maybe only) lane with this digit.
                    storage.digit_counts[warp_id][digit] += same_digit_count;
                }
            }
        }
        ::rocprim::syncthreads();
        // Reduce per-warp histograms: thread i sums digit i over all warps.
        digit_count = 0;
        if(flat_id < radix_size)
        {
            for(unsigned int w = 0; w < warps_no; w++)
            {
                digit_count += storage.digit_counts[w][flat_id];
            }
        }
    }
};
// Helper for the single-kernel radix sort path: each block loads one tile,
// sorts it entirely in-block with block_radix_sort, and writes the valid
// prefix back. Used when the whole input fits in one kernel launch with no
// cross-block digit exchange.
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    bool Descending,
    class Key,
    class Value
>
struct radix_sort_single_helper
{
    static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
    using key_type = Key;
    using value_type = Value;
    using key_codec = radix_key_codec<key_type, Descending>;
    using bit_key_type = typename key_codec::bit_key_type;
    using keys_load_type = ::rocprim::block_load<
        key_type, BlockSize, ItemsPerThread,
        ::rocprim::block_load_method::block_load_transpose>;
    using values_load_type = ::rocprim::block_load<
        value_type, BlockSize, ItemsPerThread,
        ::rocprim::block_load_method::block_load_transpose>;
    using sort_type = ::rocprim::block_radix_sort<key_type, BlockSize, ItemsPerThread, value_type>;
    // Key-only mode is signalled by value_type == rocprim::empty_type.
    static constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;

    struct storage_type
    {
        // Phases are separated by syncthreads, so the shared buffers may alias.
        union
        {
            typename keys_load_type::storage_type keys_load;
            typename values_load_type::storage_type values_load;
            typename sort_type::storage_type sort;
        };
    };

    // Loads one tile, sorts it on bits [bit, bit + current_radix_bits), and
    // stores only the valid items back to the output iterators.
    template<
        class KeysInputIterator,
        class KeysOutputIterator,
        class ValuesInputIterator,
        class ValuesOutputIterator
    >
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void sort_single(KeysInputIterator keys_input,
                     KeysOutputIterator keys_output,
                     ValuesInputIterator values_input,
                     ValuesOutputIterator values_output,
                     unsigned int size,
                     unsigned int bit,
                     unsigned int current_radix_bits,
                     storage_type& storage)
    {
        const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
        const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
        const unsigned int block_offset = flat_block_id * items_per_block;
        const unsigned int number_of_blocks = (size + items_per_block - 1) / items_per_block;
        unsigned int valid_in_last_block;
        const bool last_block = flat_block_id == (number_of_blocks - 1);
        // NOTE: these local aliases shadow the struct-level key_type/key_codec/
        // bit_key_type with the iterator-derived equivalents.
        using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
        using key_codec = radix_key_codec<key_type, Descending>;
        using bit_key_type = typename key_codec::bit_key_type;
        key_type keys[ItemsPerThread];
        value_type values[ItemsPerThread];
        if(!last_block)
        {
            valid_in_last_block = items_per_block;
            keys_load_type().load(keys_input + block_offset, keys, storage.keys_load);
            if(with_values)
            {
                ::rocprim::syncthreads();
                values_load_type().load(values_input + block_offset, values, storage.values_load);
            }
        }
        else
        {
            // Pad the partial tile with the maximal bit pattern so padding
            // sorts to the end and is simply not written out.
            const key_type out_of_bounds = key_codec::decode(bit_key_type(-1));
            valid_in_last_block = size - items_per_block * (number_of_blocks - 1);
            keys_load_type().load(
                keys_input + block_offset, keys,
                valid_in_last_block, out_of_bounds,
                storage.keys_load);
            if(with_values)
            {
                ::rocprim::syncthreads();
                values_load_type().load(
                    values_input + block_offset, values,
                    valid_in_last_block,
                    storage.values_load);
            }
        }
        ::rocprim::syncthreads();
        sort_block<Descending>(sort_type(), keys, values, storage.sort, bit, bit + current_radix_bits);
        // Store keys and values
#pragma unroll
        for(unsigned int i = 0; i < ItemsPerThread; ++i)
        {
            unsigned int item_offset = flat_id * ItemsPerThread + i;
            // valid_in_last_block == items_per_block for full blocks, so this
            // check only filters padding in the final block.
            if(item_offset < valid_in_last_block)
            {
                keys_output[block_offset + item_offset] = keys[i];
                if(with_values)
                    values_output[block_offset + item_offset] = values[i];
            }
        }
    }
};
// Helper for one pass of the multi-kernel radix sort: sorts each tile locally
// on the current digit, then scatters items to their globally correct
// positions using precomputed per-digit global start offsets. Tracks, per
// digit, the local subsequence [starts, ends] produced by the tile sort so
// that the global destination of an item is
//   (local position - local digit start) + global digit start.
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    unsigned int RadixBits,
    bool Descending,
    class Key,
    class Value,
    class Offset
>
struct radix_sort_and_scatter_helper
{
    static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
    static constexpr unsigned int radix_size = 1 << RadixBits;
    using key_type = Key;
    using value_type = Value;
    using key_codec = radix_key_codec<key_type, Descending>;
    using bit_key_type = typename key_codec::bit_key_type;
    using keys_load_type = ::rocprim::block_load<
        key_type, BlockSize, ItemsPerThread,
        ::rocprim::block_load_method::block_load_transpose>;
    using values_load_type = ::rocprim::block_load<
        value_type, BlockSize, ItemsPerThread,
        ::rocprim::block_load_method::block_load_transpose>;
    using sort_type = ::rocprim::block_radix_sort<key_type, BlockSize, ItemsPerThread, value_type>;
    using discontinuity_type = ::rocprim::block_discontinuity<unsigned int, BlockSize>;
    using bit_keys_exchange_type = ::rocprim::block_exchange<bit_key_type, BlockSize, ItemsPerThread>;
    using values_exchange_type = ::rocprim::block_exchange<value_type, BlockSize, ItemsPerThread>;
    static constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;

    struct storage_type
    {
        union
        {
            typename keys_load_type::storage_type keys_load;
            typename values_load_type::storage_type values_load;
            typename sort_type::storage_type sort;
            typename discontinuity_type::storage_type discontinuity;
            typename bit_keys_exchange_type::storage_type bit_keys_exchange;
            typename values_exchange_type::storage_type values_exchange;
        };
        // Local [start, end] positions of each digit's run in the sorted tile.
        unsigned short starts[radix_size];
        unsigned short ends[radix_size];
        // Running global destination offset of each digit (advanced per tile).
        Offset digit_starts[radix_size];
    };

    // Processes keys_input[begin_offset, end_offset) tile by tile.
    // IsFull == true promises a whole number of full tiles (no bounds checks).
    template<
        bool IsFull = false,
        class KeysInputIterator,
        class KeysOutputIterator,
        class ValuesInputIterator,
        class ValuesOutputIterator
    >
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void sort_and_scatter(KeysInputIterator keys_input,
                          KeysOutputIterator keys_output,
                          ValuesInputIterator values_input,
                          ValuesOutputIterator values_output,
                          Offset begin_offset,
                          Offset end_offset,
                          unsigned int bit,
                          unsigned int current_radix_bits,
                          Offset digit_start, // i-th thread must pass i-th digit's value
                          storage_type& storage)
    {
        const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
        if(flat_id < radix_size)
        {
            storage.digit_starts[flat_id] = digit_start;
        }
        for(Offset block_offset = begin_offset; block_offset < end_offset; block_offset += items_per_block)
        {
            key_type keys[ItemsPerThread];
            value_type values[ItemsPerThread];
            unsigned int valid_count;
            if(IsFull || (block_offset + items_per_block <= end_offset))
            {
                valid_count = items_per_block;
                keys_load_type().load(keys_input + block_offset, keys, storage.keys_load);
                if(with_values)
                {
                    ::rocprim::syncthreads();
                    values_load_type().load(values_input + block_offset, values, storage.values_load);
                }
            }
            else
            {
                valid_count = end_offset - block_offset;
                // Sort will leave "invalid" (out of size) items at the end of the sorted sequence
                const key_type out_of_bounds = key_codec::decode(bit_key_type(-1));
                keys_load_type().load(
                    keys_input + block_offset, keys,
                    valid_count, out_of_bounds,
                    storage.keys_load);
                if(with_values)
                {
                    ::rocprim::syncthreads();
                    values_load_type().load(
                        values_input + block_offset, values,
                        valid_count,
                        storage.values_load);
                }
            }
            // Default the run boundaries to valid_count so digits absent from
            // this tile are recognizable (start == valid_count) below.
            if(flat_id < radix_size)
            {
                storage.starts[flat_id] = valid_count;
                storage.ends[flat_id] = valid_count;
            }
            ::rocprim::syncthreads();
            sort_block<Descending>(sort_type(), keys, values, storage.sort, bit, bit + current_radix_bits);
            bit_key_type bit_keys[ItemsPerThread];
            unsigned int digits[ItemsPerThread];
            for(unsigned int i = 0; i < ItemsPerThread; i++)
            {
                bit_keys[i] = key_codec::encode(keys[i]);
                digits[i] = key_codec::extract_digit(bit_keys[i], bit, current_radix_bits);
            }
            bool head_flags[ItemsPerThread];
            bool tail_flags[ItemsPerThread];
            ::rocprim::not_equal_to<unsigned int> flag_op;
            ::rocprim::syncthreads();
            // Heads/tails of equal-digit runs in the sorted tile mark each
            // digit's local start/end position.
            discontinuity_type().flag_heads_and_tails(head_flags, tail_flags, digits, flag_op, storage.discontinuity);
            // Fill start and end position of subsequence for every digit
            for(unsigned int i = 0; i < ItemsPerThread; i++)
            {
                const unsigned int digit = digits[i];
                const unsigned int pos = flat_id * ItemsPerThread + i;
                if(head_flags[i])
                {
                    storage.starts[digit] = pos;
                }
                if(tail_flags[i])
                {
                    storage.ends[digit] = pos;
                }
            }
            ::rocprim::syncthreads();
            // Rearrange to striped arrangement to have faster coalesced writes instead of
            // scattering of blocked-arranged items
            bit_keys_exchange_type().blocked_to_striped(bit_keys, bit_keys, storage.bit_keys_exchange);
            if(with_values)
            {
                ::rocprim::syncthreads();
                values_exchange_type().blocked_to_striped(values, values, storage.values_exchange);
            }
            for(unsigned int i = 0; i < ItemsPerThread; i++)
            {
                const unsigned int digit = key_codec::extract_digit(bit_keys[i], bit, current_radix_bits);
                const unsigned int pos = i * BlockSize + flat_id;
                if(IsFull || (pos < valid_count))
                {
                    // Global destination: rank within the digit's local run plus
                    // the digit's running global start.
                    const Offset dst = pos - storage.starts[digit] + storage.digit_starts[digit];
                    keys_output[dst] = key_codec::decode(bit_keys[i]);
                    if(with_values)
                    {
                        values_output[dst] = values[i];
                    }
                }
            }
            ::rocprim::syncthreads();
            // Accumulate counts of the current block
            if(flat_id < radix_size)
            {
                const unsigned int digit = flat_id;
                const unsigned int start = storage.starts[digit];
                const unsigned int end = storage.ends[digit];
                // start < valid_count means the digit occurred in this tile;
                // advance its global start by the run length (end clamped to
                // the last valid position to exclude padding items).
                if(start < valid_count)
                {
                    storage.digit_starts[digit] += (::rocprim::min(valid_count - 1, end) - start + 1);
                }
            }
        }
    }
};
// Kernel body of the histogram pass: every block processes one batch of tiles
// and writes its per-digit counts to batch_digit_counts[batch * radix_size + digit].
// The first `full_batches` batches contain blocks_per_full_batch tiles, the
// remaining ones contain one tile fewer; the very last batch additionally
// handles the partially-filled tail tile.
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    unsigned int RadixBits,
    bool Descending,
    class KeysInputIterator,
    class Offset
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void fill_digit_counts(KeysInputIterator keys_input,
                       Offset size,
                       Offset* batch_digit_counts,
                       unsigned int bit,
                       unsigned int current_radix_bits,
                       unsigned int blocks_per_full_batch,
                       unsigned int full_batches)
{
    constexpr unsigned int radix_size = 1 << RadixBits;
    constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
    using count_helper_type = radix_digit_count_helper<
        ::rocprim::device_warp_size(),
        BlockSize, ItemsPerThread, RadixBits, Descending>;
    ROCPRIM_SHARED_MEMORY typename count_helper_type::storage_type storage;
    const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
    const unsigned int batch_id = ::rocprim::detail::block_id<0>();
    // Locate this batch's first item.
    Offset batch_begin;
    unsigned int batch_blocks;
    if(batch_id < full_batches)
    {
        batch_blocks = blocks_per_full_batch;
        batch_begin  = batch_id * batch_blocks;
    }
    else
    {
        // Short batches: all full batches precede them, shifting the origin.
        batch_blocks = blocks_per_full_batch - 1;
        batch_begin  = batch_id * batch_blocks + full_batches;
    }
    batch_begin *= items_per_block;
    unsigned int digit_count;
    const bool is_last_batch = batch_id == (::rocprim::detail::grid_size<0>() - 1);
    if(!is_last_batch)
    {
        // Interior batches are guaranteed full tiles; skip bounds checks.
        count_helper_type().template count_digits<true>(
            keys_input,
            batch_begin,
            batch_begin + batch_blocks * items_per_block,
            bit,
            current_radix_bits,
            storage,
            digit_count);
    }
    else
    {
        // Last batch runs up to `size` and checks item validity.
        count_helper_type().template count_digits<false>(
            keys_input,
            batch_begin,
            size,
            bit,
            current_radix_bits,
            storage,
            digit_count);
    }
    // Thread i holds the count for digit i.
    if(flat_id < radix_size)
    {
        batch_digit_counts[batch_id * radix_size + flat_id] = digit_count;
    }
}
// Kernel body of the batch-scan pass: block `digit` exclusive-scans that
// digit's counts across all batches (turning counts into per-batch start
// offsets within the digit) and stores the digit's grand total into
// digit_counts[digit]. Grid has one block per digit.
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    unsigned int RadixBits,
    class Offset
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void scan_batches(Offset* batch_digit_counts,
                  Offset* digit_counts,
                  unsigned int batches)
{
    constexpr unsigned int radix_size = 1 << RadixBits;
    using scan_type = typename ::rocprim::block_scan<Offset, BlockSize>;
    // One block per digit: block id selects the digit column to scan.
    const unsigned int digit = ::rocprim::detail::block_id<0>();
    const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
    Offset values[ItemsPerThread];
    // Gather this digit's count from every batch (0 past the end so the scan
    // input is fully populated).
    for(unsigned int i = 0; i < ItemsPerThread; i++)
    {
        const unsigned int batch_id = flat_id * ItemsPerThread + i;
        values[i] = (batch_id < batches ? batch_digit_counts[batch_id * radix_size + digit] : 0);
    }
    Offset digit_count;
    // Exclusive scan: values[i] becomes the digit's start offset within
    // batch i; digit_count receives the total over all batches.
    scan_type().exclusive_scan(values, values, 0, digit_count);
    for(unsigned int i = 0; i < ItemsPerThread; i++)
    {
        const unsigned int batch_id = flat_id * ItemsPerThread + i;
        if(batch_id < batches)
        {
            batch_digit_counts[batch_id * radix_size + digit] = values[i];
        }
    }
    if(flat_id == 0)
    {
        digit_counts[digit] = digit_count;
    }
}
// Kernel body of the digit-scan pass: one block of radix_size threads
// exclusive-scans the per-digit totals in place, converting each digit's
// total count into its global start offset in the output.
template<unsigned int RadixBits, class Offset>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void scan_digits(Offset* digit_counts)
{
    constexpr unsigned int radix_size = 1 << RadixBits;
    using scan_type = typename ::rocprim::block_scan<Offset, radix_size>;
    // Thread i owns digit i.
    const unsigned int thread_index = ::rocprim::detail::block_thread_id<0>();
    Offset count = digit_counts[thread_index];
    scan_type().exclusive_scan(count, count, 0);
    digit_counts[thread_index] = count;
}
// Free-function entry point for the single-launch radix sort: derives key and
// value types from the iterators, allocates the helper's shared memory, and
// delegates the whole tile sort to radix_sort_single_helper::sort_single.
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    bool Descending,
    class KeysInputIterator,
    class KeysOutputIterator,
    class ValuesInputIterator,
    class ValuesOutputIterator
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void sort_single(KeysInputIterator keys_input,
                 KeysOutputIterator keys_output,
                 ValuesInputIterator values_input,
                 ValuesOutputIterator values_output,
                 unsigned int size,
                 unsigned int bit,
                 unsigned int current_radix_bits)
{
    using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
    using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;
    using helper_type = radix_sort_single_helper<
        BlockSize, ItemsPerThread, Descending, key_type, value_type>;
    ROCPRIM_SHARED_MEMORY typename helper_type::storage_type storage;
    helper_type().template sort_single(
        keys_input,
        keys_output,
        values_input,
        values_output,
        size,
        bit,
        current_radix_bits,
        storage);
}
// Scatters one batch of keys (and optionally values) to their globally
// sorted positions for the current radix digit pass.
//
// Work distribution: the first `full_batches` blocks process
// `blocks_per_full_batch` tiles each; the remaining blocks process one tile
// fewer. `block_offset` is the element index at which this batch starts.
// For each digit d (one per thread, flat_id < radix_size) the global
// scatter start is digit_starts[d] (prefix over all smaller digits) plus
// batch_digit_starts[batch * radix_size + d] (this batch's prefix within
// digit d). All batches except the last contain only full tiles and take
// the unguarded <true> path; the last batch is bounded by `size`.
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    unsigned int RadixBits,
    bool Descending,
    class KeysInputIterator,
    class KeysOutputIterator,
    class ValuesInputIterator,
    class ValuesOutputIterator,
    class Offset
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void sort_and_scatter(KeysInputIterator keys_input,
                      KeysOutputIterator keys_output,
                      ValuesInputIterator values_input,
                      ValuesOutputIterator values_output,
                      Offset size,
                      const Offset * batch_digit_starts,
                      const Offset * digit_starts,
                      unsigned int bit,
                      unsigned int current_radix_bits,
                      unsigned int blocks_per_full_batch,
                      unsigned int full_batches)
{
    constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
    constexpr unsigned int radix_size = 1 << RadixBits;

    using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
    using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;
    using sort_and_scatter_helper = radix_sort_and_scatter_helper<
        BlockSize, ItemsPerThread, RadixBits, Descending, key_type, value_type, Offset
    >;

    // Block-shared scratch for the helper.
    ROCPRIM_SHARED_MEMORY typename sort_and_scatter_helper::storage_type storage;

    const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
    const unsigned int batch_id = ::rocprim::detail::block_id<0>();

    // Compute this batch's first element: full batches come first, the
    // smaller (by one tile) batches are offset by one extra tile per
    // preceding full batch.
    Offset block_offset;
    unsigned int blocks_per_batch;
    if(batch_id < full_batches)
    {
        blocks_per_batch = blocks_per_full_batch;
        block_offset = batch_id * blocks_per_batch;
    }
    else
    {
        blocks_per_batch = blocks_per_full_batch - 1;
        block_offset = batch_id * blocks_per_batch + full_batches;
    }
    block_offset *= items_per_block;

    // Global scatter start position of this thread's digit for this batch.
    Offset digit_start = 0;
    if(flat_id < radix_size)
    {
        digit_start = digit_starts[flat_id]
            + batch_digit_starts[batch_id * radix_size + flat_id];
    }

    if(batch_id < ::rocprim::detail::grid_size<0>() - 1)
    {
        // Not the last batch: every tile is full, no bounds checking needed.
        sort_and_scatter_helper().template sort_and_scatter<true>(
            keys_input, keys_output, values_input, values_output,
            block_offset,
            block_offset + blocks_per_batch * items_per_block,
            bit, current_radix_bits,
            digit_start,
            storage);
    }
    else
    {
        // Last batch: clamp the end to the input size.
        sort_and_scatter_helper().template sort_and_scatter<false>(
            keys_input, keys_output, values_input, values_output,
            block_offset,
            size,
            bit, current_radix_bits,
            digit_start,
            storage);
    }
}
// Keys-only variant (selected when WithValues is false): loads one tile of
// keys in blocked arrangement. The values parameters exist only to keep
// both overloads call-compatible and are intentionally unused here.
template<
    bool WithValues,
    class KeysInputIterator,
    class ValuesInputIterator,
    class Key,
    class Value,
    unsigned int ItemsPerThread
>
ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<!WithValues>::type
block_load_radix_impl(const unsigned int flat_id,
                      const unsigned int block_offset,
                      const unsigned int valid_in_last_block,
                      const bool last_block,
                      KeysInputIterator keys_input,
                      ValuesInputIterator values_input,
                      Key (&keys)[ItemsPerThread],
                      Value (&values)[ItemsPerThread])
{
    static_cast<void>(values_input);
    static_cast<void>(values);

    if(!last_block)
    {
        // Full tile: unguarded load.
        block_load_direct_blocked(flat_id, keys_input + block_offset, keys);
        return;
    }
    // Partial last tile: guard against reading past the end of the input.
    block_load_direct_blocked(flat_id,
                              keys_input + block_offset,
                              keys,
                              valid_in_last_block);
}
// Keys-and-values variant (selected when WithValues is true): loads one
// tile of keys and the matching tile of values in blocked arrangement,
// with both loads guarded in the partial last tile.
template<
    bool WithValues,
    class KeysInputIterator,
    class ValuesInputIterator,
    class Key,
    class Value,
    unsigned int ItemsPerThread
>
ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<WithValues>::type
block_load_radix_impl(const unsigned int flat_id,
                      const unsigned int block_offset,
                      const unsigned int valid_in_last_block,
                      const bool last_block,
                      KeysInputIterator keys_input,
                      ValuesInputIterator values_input,
                      Key (&keys)[ItemsPerThread],
                      Value (&values)[ItemsPerThread])
{
    if(!last_block)
    {
        // Full tile: unguarded loads.
        block_load_direct_blocked(flat_id, keys_input + block_offset, keys);
        block_load_direct_blocked(flat_id, values_input + block_offset, values);
        return;
    }
    // Partial last tile: guard both loads with the number of valid items.
    block_load_direct_blocked(flat_id,
                              keys_input + block_offset,
                              keys,
                              valid_in_last_block);
    block_load_direct_blocked(flat_id,
                              values_input + block_offset,
                              values,
                              valid_in_last_block);
}
// "a > b" for floating-point keys with NaNs given a deterministic position:
// both operands are remapped to the bit pattern used by the radix sort
// (-0.0 folded onto +0.0, negatives bit-inverted, positives get the sign
// bit set) and compared as unsigned integers, so NaNs are ordered by their
// bit representation instead of the always-false result of operator>.
template<class T>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto compare_nan_sensitive(const T& a, const T& b)
    -> typename std::enable_if<rocprim::is_floating_point<T>::value, bool>::type
{
    // Beware: the performance of this function is extremely vulnerable to refactoring.
    // Always check benchmark_device_segmented_radix_sort and benchmark_device_radix_sort
    // when making changes to this function.

    using bit_key_type = typename float_bit_mask<T>::bit_type;
    static constexpr auto sign_bit = float_bit_mask<T>::sign_bit;

    auto a_bits = __builtin_bit_cast(bit_key_type, a);
    auto b_bits = __builtin_bit_cast(bit_key_type, b);

    // convert -0.0 to +0.0 (a bit pattern of exactly the sign bit is -0.0)
    a_bits = a_bits == sign_bit ? 0 : a_bits;
    b_bits = b_bits == sign_bit ? 0 : b_bits;
    // invert negatives, put 1 into sign bit for positives — this makes the
    // unsigned integer order of the bit patterns match the numeric order
    ROCPRIM_UNROLL_COMMENT_PLACEHOLDER
    a_bits ^= (sign_bit & a_bits) == 0 ? sign_bit : bit_key_type(-1);
    b_bits ^= (sign_bit & b_bits) == 0 ? sign_bit : bit_key_type(-1);

    // sort numbers and NaNs according to their bit representation
    return a_bits > b_bits;
}
// Non-floating-point overload: no NaNs exist, so this simply falls back to
// the type's own operator>.
template<class T>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto compare_nan_sensitive(const T& a, const T& b)
    -> typename std::enable_if<!rocprim::is_floating_point<T>::value, bool>::type
{
    return a > b;
}
template
<
bool
Descending
,
bool
UseRadixMask
,
class
T
,
class
Enable
=
void
>
struct
radix_merge_compare
;
// Ascending order, no radix mask: "a comes before b" when a < b.
template<class T>
struct radix_merge_compare<false, false, T>
{
    ROCPRIM_DEVICE ROCPRIM_INLINE
    bool operator()(const T& a, const T& b) const
    {
        // compare_nan_sensitive(x, y) means "x > y"; swapping the
        // operands turns it into "a < b".
        return compare_nan_sensitive(b, a);
    }
};
// Descending order, no radix mask: "a comes before b" when a > b,
// which is exactly the NaN-aware greater-than comparison.
template<class T>
struct radix_merge_compare<true, false, T>
{
    ROCPRIM_DEVICE ROCPRIM_INLINE
    bool operator()(const T& a, const T& b) const
    {
        return compare_nan_sensitive(a, b);
    }
};
// Ascending order with radix mask, integral keys only: compares just the
// digit bits [start_bit, start_bit + current_radix_bits) of each key.
template<class T>
struct radix_merge_compare<false,
                           true,
                           T,
                           typename std::enable_if<rocprim::is_integral<T>::value>::type>
{
    T radix_mask;

    ROCPRIM_HOST_DEVICE ROCPRIM_INLINE
    radix_merge_compare(const unsigned int start_bit,
                        const unsigned int current_radix_bits)
    {
        // All bits below (start_bit + current_radix_bits).
        const T mask_through_top = (T(1) << (current_radix_bits + start_bit)) - 1;
        // All bits below start_bit.
        const T mask_below_start = (T(1) << start_bit) - 1;
        // XOR keeps exactly the digit bits of the current pass.
        radix_mask = mask_through_top ^ mask_below_start;
    }

    ROCPRIM_DEVICE ROCPRIM_INLINE
    bool operator()(const T& a, const T& b) const
    {
        const T a_masked = a & radix_mask;
        const T b_masked = b & radix_mask;
        // Ascending: a precedes b when a's digit is strictly smaller.
        return b_masked > a_masked;
    }
};
// Descending order with radix mask, integral keys only: compares just the
// digit bits [start_bit, start_bit + current_radix_bits) of each key.
template<class T>
struct radix_merge_compare<true,
                           true,
                           T,
                           typename std::enable_if<rocprim::is_integral<T>::value>::type>
{
    T radix_mask;

    ROCPRIM_HOST_DEVICE ROCPRIM_INLINE
    radix_merge_compare(const unsigned int start_bit,
                        const unsigned int current_radix_bits)
    {
        // All bits below (start_bit + current_radix_bits).
        const T mask_through_top = (T(1) << (current_radix_bits + start_bit)) - 1;
        // All bits below start_bit.
        const T mask_below_start = (T(1) << start_bit) - 1;
        // XOR keeps exactly the digit bits of the current pass.
        radix_mask = mask_through_top ^ mask_below_start;
    }

    ROCPRIM_DEVICE ROCPRIM_INLINE
    bool operator()(const T& a, const T& b) const
    {
        const T a_masked = a & radix_mask;
        const T b_masked = b & radix_mask;
        // Descending: a precedes b when a's digit is strictly greater.
        return a_masked > b_masked;
    }
};
// radix_merge_compare supports masks only for integrals.
// Even though masks are never used for floating-point types, this stub
// must exist so the template can be instantiated for them.
template<bool Descending, class T>
struct radix_merge_compare<Descending,
                           true,
                           T,
                           typename std::enable_if<!rocprim::is_integral<T>::value>::type>
{
    // Accepts and ignores the mask parameters so construction compiles.
    ROCPRIM_HOST_DEVICE ROCPRIM_INLINE
    radix_merge_compare(const unsigned int, const unsigned int) {}

    // Per the note above this path is not taken for floating-point keys;
    // it unconditionally reports "not before".
    ROCPRIM_DEVICE ROCPRIM_INLINE
    bool operator()(const T&, const T&) const
    {
        return false;
    }
};
// One pass of a block-wise merge: the input is partitioned into runs of
// merge_items_per_block_size sorted items, and each even/odd pair of
// adjacent runs is merged. Every thread computes, for each of its items,
// the item's final position via a binary search (merge-path rank) in the
// partner run, then scatters key (and value) there directly.
//
// Stability: for equal keys, items of the even (left) run must precede
// items of the odd (right) run — hence the extra linear/binary advance over
// equal keys when block_id_is_odd.
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    class KeysInputIterator,
    class KeysOutputIterator,
    class ValuesInputIterator,
    class ValuesOutputIterator,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void radix_block_merge_impl(KeysInputIterator keys_input,
                            KeysOutputIterator keys_output,
                            ValuesInputIterator values_input,
                            ValuesOutputIterator values_output,
                            const size_t input_size,
                            const unsigned int merge_items_per_block_size,
                            BinaryFunction compare_function)
{
    using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
    using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;

    // Values are processed only when the value type is not the empty tag.
    constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
    constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;

    const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
    const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
    const unsigned int block_offset = flat_block_id * items_per_block;
    const unsigned int number_of_blocks = (input_size + items_per_block - 1) / items_per_block;
    const bool last_block = flat_block_id == (number_of_blocks - 1);
    // Number of valid items in this tile (partial only for the last block).
    auto valid_in_last_block =
        last_block
            ? input_size - items_per_block * (number_of_blocks - 1)
            : items_per_block;

    unsigned int start_id = (flat_block_id * items_per_block) + flat_id * ItemsPerThread;
    if(start_id >= input_size)
    {
        // This thread's items all lie past the end of the input.
        return;
    }

    key_type keys[ItemsPerThread];
    value_type values[ItemsPerThread];
    block_load_radix_impl<with_values>(
        flat_id, block_offset, valid_in_last_block, last_block,
        keys_input, values_input,
        keys, values);

    ROCPRIM_UNROLL
    for(unsigned int i = 0; i < ItemsPerThread; i++)
    {
        // Skip items past the end of the input in the last tile.
        if(flat_id * ItemsPerThread + i < valid_in_last_block)
        {
            const unsigned int id = start_id + i;
            // Which sorted run this item belongs to, and its merge partner:
            // even runs pair with the following run, odd runs with the
            // preceding one.
            const unsigned int block_id = id / merge_items_per_block_size;
            const bool block_id_is_odd = block_id & 1;
            const unsigned int next_block_id = block_id_is_odd
                ? block_id - 1
                : block_id + 1;
            const unsigned int block_start =
                min(block_id * merge_items_per_block_size, (unsigned int) input_size);
            const unsigned int next_block_start =
                min(next_block_id * merge_items_per_block_size, (unsigned int) input_size);
            const unsigned int next_block_end =
                min((next_block_id + 1) * merge_items_per_block_size, (unsigned int) input_size);

            if(next_block_start == input_size)
            {
                // The partner run is empty (past the end): the item keeps
                // its position. (The generic path below would compute the
                // same offset; this is an early direct write.)
                keys_output[id] = keys[i];
                if(with_values)
                {
                    values_output[id] = values[i];
                }
            }

            // Binary search: lower bound of keys[i] within the partner run.
            unsigned int left_id = next_block_start;
            unsigned int right_id = next_block_end;
            while(left_id < right_id)
            {
                unsigned int mid_id = (left_id + right_id) / 2;
                key_type mid_key = keys_input[mid_id];
                bool smaller = compare_function(mid_key, keys[i]);
                left_id = smaller ? mid_id + 1 : left_id;
                right_id = smaller ? right_id : mid_id;
            }

            right_id = next_block_end;
            if(block_id_is_odd && left_id != right_id)
            {
                // Odd (right) run items must come after equal keys from the
                // even (left) run: advance past the run of equal keys.
                key_type upper_key = keys_input[left_id];
                while(!compare_function(upper_key, keys[i])
                      && !compare_function(keys[i], upper_key)
                      && left_id < right_id)
                {
                    unsigned int mid_id = (left_id + right_id) / 2;
                    key_type mid_key = keys_input[mid_id];
                    bool equal = !compare_function(mid_key, keys[i])
                        && !compare_function(keys[i], mid_key);
                    left_id = equal ? mid_id + 1 : left_id + 1;
                    right_id = equal ? right_id : mid_id;
                    upper_key = keys_input[left_id];
                }
            }

            // Final rank = rank within own run + rank within partner run
            // + start of the merged (pair) region.
            unsigned int offset = 0;
            offset += id - block_start;
            offset += left_id - next_block_start;
            offset += min(block_start, next_block_start);
            keys_output[offset] = keys[i];
            if(with_values)
            {
                values_output[offset] = values[i];
            }
        }
    }
}
}
// end namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_RADIX_SORT_HPP_
3rdparty/cub/rocprim/device/detail/device_reduce.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_REDUCE_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_REDUCE_HPP_
#include <type_traits>
#include <iterator>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../types.hpp"
#include "../../block/block_load.hpp"
#include "../../block/block_reduce.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace
detail
{
// Helper functions for reducing final value with
// initial value.
// Enabled when WithInitialValue is true: folds the initial value into the
// block's reduction result, with the initial value as the left operand.
template<
    bool WithInitialValue,
    class T,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto reduce_with_initial(T output,
                         T initial_value,
                         BinaryFunction reduce_op)
    -> typename std::enable_if<WithInitialValue, T>::type
{
    const T combined = reduce_op(initial_value, output);
    return combined;
}
// Enabled when WithInitialValue is false: there is no initial value to
// apply, so the reduction result is returned unchanged.
template<
    bool WithInitialValue,
    class T,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto reduce_with_initial(T output,
                         T initial_value,
                         BinaryFunction reduce_op)
    -> typename std::enable_if<!WithInitialValue, T>::type
{
    static_cast<void>(initial_value);
    static_cast<void>(reduce_op);
    return output;
}
// One block-level reduction pass: each block reduces its tile of the input
// and thread 0 writes the tile's partial result to output[flat_block_id].
// Only the last block handles a partial tile; the initial value is applied
// (by thread 0, via reduce_with_initial) only when WithInitialValue is set,
// and is used alone when the input is empty.
template<
    bool WithInitialValue, // true when the initial_value should be used (e.g. when inital_value should be used in results)
    class Config,
    class ResultType,
    class InputIterator,
    class OutputIterator,
    class InitValueType,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void block_reduce_kernel_impl(InputIterator input,
                              const size_t input_size,
                              OutputIterator output,
                              InitValueType initial_value,
                              BinaryFunction reduce_op)
{
    constexpr unsigned int block_size = Config::block_size;
    constexpr unsigned int items_per_thread = Config::items_per_thread;

    using result_type = ResultType;
    using block_reduce_type = ::rocprim::block_reduce<
        result_type, block_size, Config::block_reduce_method
    >;

    constexpr unsigned int items_per_block = block_size * items_per_thread;
    const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
    const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
    const unsigned int block_offset = flat_block_id * items_per_block;
    const unsigned int number_of_blocks = ::rocprim::detail::grid_size<0>();
    // Computed by every block, but meaningful only in the last one.
    auto valid_in_last_block = input_size - items_per_block * (number_of_blocks - 1);

    result_type values[items_per_thread];
    result_type output_value;
    if(flat_block_id == (number_of_blocks - 1)) // last block
    {
        // Guarded load; slots past valid_in_last_block stay uninitialized
        // and are excluded both by the loop guard below and by passing the
        // valid count to block_reduce.
        block_load_direct_striped<block_size>(
            flat_id,
            input + block_offset,
            values,
            valid_in_last_block);

        // Per-thread sequential reduction over this thread's valid items.
        output_value = values[0];
        ROCPRIM_UNROLL
        for(unsigned int i = 1; i < items_per_thread; i++)
        {
            unsigned int offset = i * block_size;
            if(flat_id + offset < valid_in_last_block)
            {
                output_value = reduce_op(output_value, values[i]);
            }
        }

        block_reduce_type()
            .reduce(
                output_value, // input
                output_value, // output
                valid_in_last_block,
                reduce_op);
    }
    else
    {
        block_load_direct_striped<block_size>(
            flat_id,
            input + block_offset,
            values); // load input values into values

        block_reduce_type()
            .reduce(
                values, // input
                output_value, // output
                reduce_op);
    }

    // Save value into output
    if(flat_id == 0)
    {
        output[flat_block_id] = input_size == 0
            // Empty input: the result is the initial value alone.
            ? static_cast<result_type>(initial_value)
            : reduce_with_initial<WithInitialValue>(
                  output_value,
                  static_cast<result_type>(initial_value),
                  reduce_op);
    }
}
// Returns size of temporary storage in bytes.
// Returns the total temporary storage (in bytes) needed by the multi-pass
// reduction: each pass shrinks the problem by a factor of items_per_block
// and needs one T per block of that pass; the per-pass buffer sizes are
// summed until a single block suffices.
template<class T>
size_t reduce_get_temporary_storage_bytes(size_t input_size,
                                          size_t items_per_block)
{
    size_t bytes = 0;
    while(input_size > items_per_block)
    {
        // Number of blocks (= partial results) produced by this pass.
        input_size = (input_size + items_per_block - 1) / items_per_block;
        bytes += input_size * sizeof(T);
    }
    return bytes;
}
}
// end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_REDUCE_HPP_
3rdparty/cub/rocprim/device/detail/device_reduce_by_key.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2017-2020 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_REDUCE_BY_KEY_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_REDUCE_BY_KEY_HPP_
#include <iterator>
#include <utility>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../block/block_discontinuity.hpp"
#include "../../block/block_load_func.hpp"
#include "../../block/block_load.hpp"
#include "../../block/block_store.hpp"
#include "../../block/block_scan.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace
detail
{
// Carry-out record produced per batch by the reduce-by-key kernel: the
// running reduction of the batch's last (possibly unfinished) segment.
template<class Value>
struct carry_out
{
    ROCPRIM_DEVICE ROCPRIM_INLINE
    carry_out() = default;

    ROCPRIM_DEVICE ROCPRIM_INLINE
    carry_out(const carry_out& rhs) = default;

    // The previous hand-written copy assignment performed exactly the
    // memberwise copy the compiler generates; default it like the
    // constructors above (clang-tidy: modernize-use-equals-default).
    ROCPRIM_DEVICE ROCPRIM_INLINE
    carry_out& operator=(const carry_out& rhs) = default;

    Value value; // carry-out of the current batch
    // Output segment index the carry-out applies to (reduce_by_key sets it
    // to block_start + rank of the batch's last item).
    unsigned int destination;
    bool next_has_carry_in; // the next batch has carry-in (i.e. it continues the last segment from the current batch)
};
// (segment-head-count, value) pair scanned by block_scan to implement
// scan-by-key; see scan_by_key_op below for the combine semantics.
template<class Value>
struct scan_by_key_pair
{
    ROCPRIM_DEVICE ROCPRIM_INLINE
    scan_by_key_pair() = default;

    ROCPRIM_DEVICE ROCPRIM_INLINE
    scan_by_key_pair(const scan_by_key_pair& rhs) = default;

    // The previous hand-written copy assignment performed exactly the
    // memberwise copy the compiler generates; default it like the
    // constructors above (clang-tidy: modernize-use-equals-default).
    ROCPRIM_DEVICE ROCPRIM_INLINE
    scan_by_key_pair& operator=(const scan_by_key_pair& rhs) = default;

    // Running count of segment heads (the current segment's 1-based index
    // after the scan).
    unsigned int key;
    Value value; // segmented scan result
};
// Special operator which allows to calculate scan-by-key using block_scan.
// block_scan supports non-commutative scan operators.
// Initial values of pairs' keys must be 1 for the first item (head) of segment and 0 otherwise.
// As a result key contains the current segment's index and value contains segmented scan result.
// Special operator which allows calculating scan-by-key using block_scan
// (which supports non-commutative scan operators). Initial pair keys must
// be 1 for each segment head and 0 otherwise; after the scan, `key` holds
// the segment's index and `value` the segmented scan result.
template<class Pair, class BinaryFunction>
struct scan_by_key_op
{
    BinaryFunction reduce_op;

    ROCPRIM_DEVICE ROCPRIM_INLINE
    scan_by_key_op(BinaryFunction reduce_op)
        : reduce_op(reduce_op)
    {}

    ROCPRIM_DEVICE ROCPRIM_INLINE
    Pair operator()(const Pair& a, const Pair& b)
    {
        Pair result;
        result.key = a.key + b.key;
        if(b.key != 0)
        {
            // b starts a new segment: discard the left-hand accumulation.
            result.value = b.value;
        }
        else
        {
            result.value = reduce_op(a.value, b.value);
        }
        return result;
    }
};
// Wrappers that reverse results of key comparizon functions to use them as flag_op of block_discontinuity
// (for example, equal_to will work as not_equal_to and divide items into segments by keys)
// Wrapper that inverts a key comparison so it can be used as the flag_op
// of block_discontinuity (e.g. equal_to behaves as not_equal_to, flagging
// the boundaries between key segments).
template<class Key, class KeyCompareFunction>
struct key_flag_op
{
    KeyCompareFunction key_compare_op;

    ROCPRIM_DEVICE ROCPRIM_INLINE
    key_flag_op(KeyCompareFunction key_compare_op)
        : key_compare_op(key_compare_op)
    {}

    ROCPRIM_DEVICE ROCPRIM_INLINE
    bool operator()(const Key& a, const Key& b)
    {
        const bool same_segment = key_compare_op(a, b);
        return !same_segment;
    }
};
// This wrapper processes only part of items and flags (valid_count - 1)th item (for tails)
// and (valid_count)th item (for heads), all items after valid_count are unflagged.
// Like key_flag_op, but only items before valid_count participate: the
// item at index valid_count is always flagged (closing the last valid
// segment) and everything after it is left unflagged.
template<class Key, class KeyCompareFunction>
struct guarded_key_flag_op
{
    KeyCompareFunction key_compare_op;
    unsigned int valid_count;

    ROCPRIM_DEVICE ROCPRIM_INLINE
    guarded_key_flag_op(KeyCompareFunction key_compare_op, unsigned int valid_count)
        : key_compare_op(key_compare_op), valid_count(valid_count)
    {}

    ROCPRIM_DEVICE ROCPRIM_INLINE
    bool operator()(const Key& a, const Key& b, unsigned int b_index)
    {
        if(b_index == valid_count)
        {
            // First invalid item: always a boundary.
            return true;
        }
        // Valid items: boundary when the keys differ; items past
        // valid_count: never flagged. Note key_compare_op is only invoked
        // for valid items, as in the original short-circuit expression.
        return b_index < valid_count && !key_compare_op(a, b);
    }
};
// First pass of reduce-by-key: counts, per batch, how many unique-key
// segments end inside the batch and writes the count to
// unique_counts[batch_id]. Segment tails are detected with
// block_discontinuity; counts are accumulated per warp via ballot/bit_count
// and summed by thread 0 at the end.
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    class KeysInputIterator,
    class KeyCompareFunction
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void fill_unique_counts(KeysInputIterator keys_input,
                        unsigned int size,
                        unsigned int * unique_counts,
                        KeyCompareFunction key_compare_op,
                        unsigned int blocks_per_full_batch,
                        unsigned int full_batches)
{
    constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
    constexpr unsigned int warp_size = ::rocprim::device_warp_size();
    constexpr unsigned int warps_no = BlockSize / warp_size;

    using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;

    using keys_load_type = ::rocprim::block_load<
        key_type, BlockSize, ItemsPerThread,
        ::rocprim::block_load_method::block_load_transpose>;
    using discontinuity_type = ::rocprim::block_discontinuity<key_type, BlockSize>;

    // The load and discontinuity scratch areas are never live at the same
    // time (a syncthreads separates their uses), so they share storage.
    ROCPRIM_SHARED_MEMORY struct
    {
        union
        {
            typename keys_load_type::storage_type keys_load;
            typename discontinuity_type::storage_type discontinuity;
        };
        unsigned int unique_counts[warps_no];
    } storage;

    const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
    const unsigned int batch_id = ::rocprim::detail::block_id<0>();
    const unsigned int lane_id = ::rocprim::lane_id();
    const unsigned int warp_id = ::rocprim::warp_id<0, 1, 1>();

    // First full_batches blocks process blocks_per_full_batch tiles; the
    // rest process one tile fewer.
    unsigned int block_offset;
    unsigned int blocks_per_batch;
    if(batch_id < full_batches)
    {
        blocks_per_batch = blocks_per_full_batch;
        block_offset = batch_id * blocks_per_batch;
    }
    else
    {
        blocks_per_batch = blocks_per_full_batch - 1;
        block_offset = batch_id * blocks_per_batch + full_batches;
    }
    block_offset *= items_per_block;

    unsigned int warp_unique_count = 0;
    for(unsigned int bi = 0; bi < blocks_per_batch; bi++)
    {
        const bool is_last_block = (block_offset + items_per_block >= size);

        key_type keys[ItemsPerThread];
        unsigned int valid_count;
        // Protect the shared load storage from the previous iteration's
        // discontinuity use.
        ::rocprim::syncthreads();
        if(is_last_block)
        {
            valid_count = size - block_offset;
            keys_load_type().load(keys_input + block_offset, keys, valid_count, storage.keys_load);
        }
        else
        {
            valid_count = items_per_block;
            keys_load_type().load(keys_input + block_offset, keys, storage.keys_load);
        }

        bool tail_flags[ItemsPerThread];
        // Default successor for the very last thread; overwritten below
        // with the first key of the next tile when one exists.
        key_type successor_key = keys[ItemsPerThread - 1];
        // Protect the shared storage before switching it to discontinuity.
        ::rocprim::syncthreads();
        if(is_last_block)
        {
            discontinuity_type().flag_tails(
                tail_flags, successor_key, keys,
                guarded_key_flag_op<key_type, KeyCompareFunction>(key_compare_op, valid_count),
                storage.discontinuity);
        }
        else
        {
            if(flat_id == BlockSize - 1)
            {
                successor_key = keys_input[block_offset + items_per_block];
            }
            discontinuity_type().flag_tails(
                tail_flags, successor_key, keys,
                key_flag_op<key_type, KeyCompareFunction>(key_compare_op),
                storage.discontinuity);
        }

        // Count flagged tails across the warp.
        for(unsigned int i = 0; i < ItemsPerThread; i++)
        {
            warp_unique_count += ::rocprim::bit_count(::rocprim::ballot(tail_flags[i]));
        }

        block_offset += items_per_block;
    }

    if(lane_id == 0)
    {
        storage.unique_counts[warp_id] = warp_unique_count;
    }
    ::rocprim::syncthreads();

    // Thread 0 sums the per-warp counts into the batch total.
    if(flat_id == 0)
    {
        unsigned int batch_unique_count = 0;
        for(unsigned int w = 0; w < warps_no; w++)
        {
            batch_unique_count += storage.unique_counts[w];
        }
        unique_counts[batch_id] = batch_unique_count;
    }
}
// Second pass of reduce-by-key: exclusive-scans the per-batch unique
// counts in place (turning counts into start offsets for each batch) and
// writes the grand total through unique_count_output. Intended for a
// single block whose BlockSize * ItemsPerThread covers `batches` entries.
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    class UniqueCountOutputIterator
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void scan_unique_counts(unsigned int * unique_counts,
                        UniqueCountOutputIterator unique_count_output,
                        unsigned int batches)
{
    using load_type = ::rocprim::block_load<
        unsigned int, BlockSize, ItemsPerThread,
        ::rocprim::block_load_method::block_load_transpose>;
    using store_type = ::rocprim::block_store<
        unsigned int, BlockSize, ItemsPerThread,
        ::rocprim::block_store_method::block_store_transpose>;
    using scan_type = typename ::rocprim::block_scan<unsigned int, BlockSize>;

    // Load, scan and store scratch never overlap in time (syncthreads
    // between uses), so they share one shared-memory allocation.
    ROCPRIM_SHARED_MEMORY union
    {
        typename load_type::storage_type load;
        typename store_type::storage_type store;
        typename scan_type::storage_type scan;
    } storage;

    const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();

    unsigned int values[ItemsPerThread];
    // Guarded load; out-of-range slots are filled with 0 so they do not
    // affect the scan total.
    load_type().load(unique_counts, values, batches, 0, storage.load);

    unsigned int unique_count;
    ::rocprim::syncthreads();
    scan_type().exclusive_scan(values, values, 0, unique_count);

    ::rocprim::syncthreads();
    store_type().store(unique_counts, values, batches, storage.store);

    if(flat_id == 0)
    {
        *unique_count_output = unique_count;
    }
}
// Main reduce-by-key pass: each batch writes its unique keys and
// (possibly partial) per-segment aggregates starting at
// unique_starts[batch_id]. Segments are found with block_discontinuity,
// reduced with a non-commutative scan-by-key (see scan_by_key_op), and
// carry-in/carry-out state is threaded between tiles of the same batch
// through shared memory and between batches through carry_outs /
// leading_aggregates (finalized by a later kernel).
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    class KeysInputIterator,
    class ValuesInputIterator,
    class Result,
    class UniqueOutputIterator,
    class AggregatesOutputIterator,
    class KeyCompareFunction,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void reduce_by_key(KeysInputIterator keys_input,
                   ValuesInputIterator values_input,
                   unsigned int size,
                   const unsigned int * unique_starts,
                   carry_out<Result> * carry_outs,
                   Result * leading_aggregates,
                   UniqueOutputIterator unique_output,
                   AggregatesOutputIterator aggregates_output,
                   KeyCompareFunction key_compare_op,
                   BinaryFunction reduce_op,
                   unsigned int blocks_per_full_batch,
                   unsigned int full_batches)
{
    constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;

    using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
    using result_type = Result;

    using keys_load_type = ::rocprim::block_load<
        key_type, BlockSize, ItemsPerThread,
        ::rocprim::block_load_method::block_load_transpose>;
    using values_load_type = ::rocprim::block_load<
        result_type, BlockSize, ItemsPerThread,
        ::rocprim::block_load_method::block_load_transpose>;
    using discontinuity_type = ::rocprim::block_discontinuity<key_type, BlockSize>;
    using scan_type = ::rocprim::block_scan<scan_by_key_pair<result_type>, BlockSize>;

    // The union members are used at disjoint times (separated by
    // syncthreads); unique_count / has_carry_in / carry_in persist across
    // those phases and live outside the union.
    ROCPRIM_SHARED_MEMORY struct
    {
        union
        {
            typename keys_load_type::storage_type keys_load;
            typename values_load_type::storage_type values_load;
            typename discontinuity_type::storage_type discontinuity;
            typename scan_type::storage_type scan;
        };
        unsigned int unique_count;
        bool has_carry_in;
        detail::raw_storage<result_type> carry_in;
    } storage;

    const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
    const unsigned int batch_id = ::rocprim::detail::block_id<0>();

    // Batch work distribution: first full_batches blocks get
    // blocks_per_full_batch tiles, the rest one tile fewer.
    unsigned int block_offset;
    unsigned int blocks_per_batch;
    if(batch_id < full_batches)
    {
        blocks_per_batch = blocks_per_full_batch;
        block_offset = batch_id * blocks_per_batch;
    }
    else
    {
        blocks_per_batch = blocks_per_full_batch - 1;
        block_offset = batch_id * blocks_per_batch + full_batches;
    }
    block_offset *= items_per_block;

    // Where this batch's unique keys start in the output arrays.
    const unsigned int batch_start = unique_starts[batch_id];
    unsigned int block_start = batch_start;

    if(flat_id == 0)
    {
        // The batch has carry-in when its first key continues the segment
        // ending the previous batch.
        storage.has_carry_in =
            (block_offset > 0)
            && key_compare_op(keys_input[block_offset - 1], keys_input[block_offset]);
    }

    for(unsigned int bi = 0; bi < blocks_per_batch; bi++)
    {
        const bool is_last_block = (block_offset + items_per_block >= size);

        key_type keys[ItemsPerThread];
        result_type values[ItemsPerThread];
        unsigned int valid_count;
        if(is_last_block)
        {
            valid_count = size - block_offset;
            keys_load_type().load(keys_input + block_offset, keys, valid_count, storage.keys_load);
            ::rocprim::syncthreads();
            values_load_type().load(values_input + block_offset, values, valid_count, storage.values_load);
        }
        else
        {
            valid_count = items_per_block;
            keys_load_type().load(keys_input + block_offset, keys, storage.keys_load);
            ::rocprim::syncthreads();
            values_load_type().load(values_input + block_offset, values, storage.values_load);
        }

        if(bi > 0 && flat_id == 0 && storage.has_carry_in)
        {
            // Apply carry-out of the previous block as carry-in for the first segment
            values[0] = reduce_op(storage.carry_in.get(), values[0]);
        }

        bool head_flags[ItemsPerThread];
        bool tail_flags[ItemsPerThread];
        // Default successor for the very last item; overwritten below with
        // the first key of the next tile when one exists.
        key_type successor_key = keys[ItemsPerThread - 1];
        ::rocprim::syncthreads();
        if(is_last_block)
        {
            discontinuity_type().flag_heads_and_tails(
                head_flags, tail_flags, successor_key, keys,
                guarded_key_flag_op<key_type, KeyCompareFunction>(key_compare_op, valid_count),
                storage.discontinuity);
        }
        else
        {
            if(flat_id == BlockSize - 1)
            {
                successor_key = keys_input[block_offset + items_per_block];
            }
            discontinuity_type().flag_heads_and_tails(
                head_flags, tail_flags, successor_key, keys,
                key_flag_op<key_type, KeyCompareFunction>(key_compare_op),
                storage.discontinuity);
        }

        // Build pairs and run non-commutative inclusive scan to calculate scan-by-key
        // and indices (ranks) of each segment:
        // input:
        //   keys          | 1 1 1 2 3 3 4 4 |
        //   head_flags    | +     + +   +   |
        //   values        | 2 0 1 4 2 3 1 5 |
        // result:
        //   scan values   | 2 2 3 4 2 5 1 6 |
        //   scan keys     | 1 1 1 2 3 3 4 4 |
        //   ranks (key-1) | 0 0 0 1 2 2 3 3 |
        scan_by_key_pair<result_type> pairs[ItemsPerThread];
        for(unsigned int i = 0; i < ItemsPerThread; i++)
        {
            pairs[i].key = head_flags[i];
            pairs[i].value = values[i];
        }
        scan_by_key_op<scan_by_key_pair<result_type>, BinaryFunction> scan_op(reduce_op);
        ::rocprim::syncthreads();
        scan_type().inclusive_scan(pairs, pairs, storage.scan, scan_op);

        unsigned int ranks[ItemsPerThread];
        for(unsigned int i = 0; i < ItemsPerThread; i++)
        {
            ranks[i] = pairs[i].key - 1; // The first item is always flagged as head, so indices start from 1
            values[i] = pairs[i].value;
        }

        if(flat_id == BlockSize - 1)
        {
            // Number of segments that END in this tile.
            storage.unique_count =
                ranks[ItemsPerThread - 1]
                + (tail_flags[ItemsPerThread - 1] ? 1 : 0);
        }
        ::rocprim::syncthreads();
        const unsigned int unique_count = storage.unique_count;

        if(flat_id == 0)
        {
            // The first item must be written only if it is the first item of the current segment
            // (otherwise it is written by one of previous blocks)
            head_flags[0] = !storage.has_carry_in;
        }
        if(is_last_block)
        {
            // Unflag the head after the last segment as it will be written out of bounds
            for(unsigned int i = 0; i < ItemsPerThread; i++)
            {
                if(ranks[i] >= unique_count)
                {
                    head_flags[i] = false;
                }
            }
        }
        ::rocprim::syncthreads();

        if(flat_id == BlockSize - 1)
        {
            if(bi == blocks_per_batch - 1)
            {
                // Save carry-out of the last block of the current batch
                carry_outs[batch_id].value = values[ItemsPerThread - 1];
                carry_outs[batch_id].destination = block_start + ranks[ItemsPerThread - 1];
                carry_outs[batch_id].next_has_carry_in = !tail_flags[ItemsPerThread - 1];
            }
            else
            {
                // Save carry-out to use it as carry-in for the next block of the current batch
                storage.has_carry_in = !tail_flags[ItemsPerThread - 1];
                storage.carry_in.get() = values[ItemsPerThread - 1];
            }
        }

        if(batch_id > 0 && block_start == batch_start)
        {
            for(unsigned int i = 0; i < ItemsPerThread; i++)
            {
                // Write the scanned value of the last item of the first segment of the current batch
                // (the leading possible incomplete aggregate) to calculate the final aggregate in the next kernel.
                // The intermediate array is used instead of aggregates_output because
                // aggregates_output may be write-only.
                if(tail_flags[i] && ranks[i] == 0)
                {
                    leading_aggregates[batch_id - 1] = values[i];
                }
            }
        }

        // Save unique keys and aggregates (some aggregates contains partial values
        // and will be updated later by calculating scan-by-key of carry-outs)
        for(unsigned int i = 0; i < ItemsPerThread; i++)
        {
            if(head_flags[i])
            {
                // Write the key of the first item of the segment as a unique key
                unique_output[block_start + ranks[i]] = keys[i];
            }
            if(tail_flags[i])
            {
                // Write the scanned value of the last item of the segment as an aggregate (reduction of the segment)
                aggregates_output[block_start + ranks[i]] = values[i];
            }
        }

        block_offset += items_per_block;
        block_start += unique_count;
    }
}
// Second pass of reduce-by-key: combines the per-batch carry-outs (partial
// aggregates of segments that cross batch boundaries) and scatters the final
// aggregates into aggregates_output.
//
// carry_outs         - one carry_out record per batch boundary (batches - 1 items)
// leading_aggregates - scanned value of the leading (possibly incomplete) segment
//                      of each batch, written by the previous kernel
// aggregates_output  - destination for the corrected segment aggregates
// reduce_op          - binary reduction operator
// batches            - number of batches processed by the previous kernel
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    class Result,
    class AggregatesOutputIterator,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void scan_and_scatter_carry_outs(const carry_out<Result>*  carry_outs,
                                 const Result*             leading_aggregates,
                                 AggregatesOutputIterator  aggregates_output,
                                 BinaryFunction            reduce_op,
                                 unsigned int              batches)
{
    using result_type        = Result;
    using discontinuity_type = ::rocprim::block_discontinuity<unsigned int, BlockSize>;
    using scan_type          = ::rocprim::block_scan<scan_by_key_pair<result_type>, BlockSize>;

    ROCPRIM_SHARED_MEMORY struct
    {
        typename discontinuity_type::storage_type discontinuity;
        typename scan_type::storage_type          scan;
    } storage;

    const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();

    // Load the carry-out records of all batch boundaries into registers.
    carry_out<result_type> cs[ItemsPerThread];
    block_load_direct_blocked(flat_id, carry_outs, cs, batches - 1);

    // Split records into destination indices and values.
    unsigned int destinations[ItemsPerThread];
    result_type  values[ItemsPerThread];
    for(unsigned int i = 0; i < ItemsPerThread; i++)
    {
        destinations[i] = cs[i].destination;
        values[i]       = cs[i].value;
    }

    bool head_flags[ItemsPerThread];
    bool tail_flags[ItemsPerThread];
    ::rocprim::equal_to<unsigned int> compare_op;
    // If a carry-out of the current batch has the same destination as previous batches,
    // then we need to scan its value with values of those previous batches.
    discontinuity_type().flag_heads_and_tails(
        head_flags,
        tail_flags,
        destinations[ItemsPerThread - 1], // Do not always flag the last item in the block
        destinations,
        guarded_key_flag_op<unsigned int, decltype(compare_op)>(compare_op, batches - 1),
        storage.discontinuity);

    // Pair each value with its head flag and run an inclusive segmented scan,
    // so carry-outs targeting the same destination get combined.
    scan_by_key_pair<result_type> pairs[ItemsPerThread];
    for(unsigned int i = 0; i < ItemsPerThread; i++)
    {
        pairs[i].key   = head_flags[i];
        pairs[i].value = values[i];
    }
    scan_by_key_op<scan_by_key_pair<result_type>, BinaryFunction> scan_op(reduce_op);
    scan_type().inclusive_scan(pairs, pairs, storage.scan, scan_op);

    // Scatter the last carry-out of each segment as carry-ins
    for(unsigned int i = 0; i < ItemsPerThread; i++)
    {
        if(tail_flags[i])
        {
            const unsigned int dst       = destinations[i];
            const result_type  aggregate = pairs[i].value;
            if(cs[i].next_has_carry_in)
            {
                // The next batch continues the last segment from the current batch,
                // combine two partial aggregates
                aggregates_output[dst]
                    = reduce_op(aggregate, leading_aggregates[flat_id * ItemsPerThread + i]);
            }
            else
            {
                // Overwrite the aggregate because the next batch starts with a different key
                aggregates_output[dst] = aggregate;
            }
        }
    }
}
}
// end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_REDUCE_BY_KEY_HPP_
3rdparty/cub/rocprim/device/detail/device_scan_by_key.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_BY_KEY_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_BY_KEY_HPP_
#include "device_scan_common.hpp"
#include "lookback_scan_state.hpp"
#include "ordered_block_id.hpp"
#include "../../block/block_discontinuity.hpp"
#include "../../block/block_load.hpp"
#include "../../block/block_scan.hpp"
#include "../../block/block_store.hpp"
#include "../../config.hpp"
#include "../../detail/binary_op_wrappers.hpp"
#include "../../intrinsics/thread.hpp"
#include "../../types/tuple.hpp"
#include <type_traits>
BEGIN_ROCPRIM_NAMESPACE
namespace
detail
{
// Loads a tile of keys and values and converts them into (value, flag) pairs
// suitable for a headflag-wrapped scan-by-key.
//
// Fix: block_load_values previously instantiated ::rocprim::block_load with
// load_keys_method, silently ignoring the load_values_method template
// parameter. It now honors load_values_method (behavior-compatible for callers
// that pass the same method for both, as the kernel in this file does).
template<
    bool Exclusive,
    unsigned int block_size,
    unsigned int items_per_thread,
    typename key_type,
    typename result_type,
    ::rocprim::block_load_method load_keys_method,
    ::rocprim::block_load_method load_values_method
>
struct load_values_flagged
{
    using block_load_keys
        = ::rocprim::block_load<key_type, block_size, items_per_thread, load_keys_method>;
    using block_discontinuity = ::rocprim::block_discontinuity<key_type, block_size>;
    using block_load_values
        = ::rocprim::block_load<result_type, block_size, items_per_thread, load_values_method>;

    union storage_type
    {
        struct
        {
            typename block_load_keys::storage_type     load;
            typename block_discontinuity::storage_type flag;
        } keys;
        typename block_load_values::storage_type load_values;
    };

    // Load flagged values
    // - if the scan is exclusive the last item of each segment (range where the keys compare equal)
    //   is flagged and reset to the initial value. Adding the last item of the range to the
    //   second to last using `headflag_scan_op_wrapper` will return the initial_value,
    //   which is exactly what should be saved at the start of the next range.
    // - if the scan is inclusive, then the first item of each segment is marked, and it will
    //   restart the scan from that value
    template<typename KeyIterator, typename ValueIterator, typename CompareFunction>
    ROCPRIM_DEVICE
    void load(KeyIterator        keys_input,
              ValueIterator      values_input,
              CompareFunction    compare,
              const result_type  initial_value,
              const unsigned int flat_block_id,
              const size_t       starting_block,
              const size_t       number_of_blocks,
              const unsigned int flat_thread_id,
              const size_t       size,
              rocprim::tuple<result_type, bool> (&wrapped_values)[items_per_thread],
              storage_type& storage)
    {
        constexpr static unsigned int items_per_block = items_per_thread * block_size;

        const unsigned int block_offset = flat_block_id * items_per_block;
        KeyIterator   block_keys   = keys_input + block_offset;
        ValueIterator block_values = values_input + block_offset;

        key_type    keys[items_per_thread];
        result_type values[items_per_thread];
        bool        flags[items_per_thread];

        // Segment boundaries are where keys do NOT compare equal.
        auto not_equal = [compare](const auto& a, const auto& b) mutable
        {
            return !compare(a, b);
        };

        const auto flag_segment_boundaries = [&]()
        {
            if(Exclusive)
            {
                // Exclusive: flag tails, using the first key of the next tile as successor
                // (any value works for the last tile; it only affects out-of-range items).
                const key_type tile_successor
                    = starting_block + flat_block_id < number_of_blocks - 1
                        ? block_keys[items_per_block]
                        : *block_keys;
                block_discontinuity{}.flag_tails(flags, tile_successor, keys, not_equal,
                                                 storage.keys.flag);
            }
            else
            {
                // Inclusive: flag heads, using the last key of the previous tile as predecessor.
                const key_type tile_predecessor
                    = starting_block + flat_block_id > 0 ? block_keys[-1] : *block_keys;
                block_discontinuity{}.flag_heads(flags, tile_predecessor, keys, not_equal,
                                                 storage.keys.flag);
            }
        };

        if(starting_block + flat_block_id < number_of_blocks - 1)
        {
            // Full tile: no guarding needed.
            block_load_keys{}.load(block_keys, keys, storage.keys.load);
            flag_segment_boundaries();
            // Reusing shared memory for loading values
            ::rocprim::syncthreads();
            block_load_values{}.load(block_values, values, storage.load_values);

            ROCPRIM_UNROLL
            for(unsigned int i = 0; i < items_per_thread; ++i)
            {
                rocprim::get<0>(wrapped_values[i])
                    = (Exclusive && flags[i]) ? initial_value : values[i];
                rocprim::get<1>(wrapped_values[i]) = flags[i];
            }
        }
        else
        {
            // Last (possibly partial) tile: guard loads and the pair construction.
            const unsigned int valid_in_last_block
                = static_cast<unsigned int>(size - items_per_block * (number_of_blocks - 1));
            block_load_keys{}.load(block_keys,
                                   keys,
                                   valid_in_last_block,
                                   *block_keys, // Any value is okay, so discontinuity doesn't access undefined items
                                   storage.keys.load);
            flag_segment_boundaries();
            // Reusing shared memory for loading values
            ::rocprim::syncthreads();
            block_load_values{}.load(block_values, values, valid_in_last_block,
                                     storage.load_values);

            ROCPRIM_UNROLL
            for(unsigned int i = 0; i < items_per_thread; ++i)
            {
                if(flat_thread_id * items_per_thread + i >= valid_in_last_block)
                {
                    break;
                }
                rocprim::get<0>(wrapped_values[i])
                    = (Exclusive && flags[i]) ? initial_value : values[i];
                rocprim::get<1>(wrapped_values[i]) = flags[i];
            }
        }
    }
};
// Strips the head flags from scanned (value, flag) pairs and stores the plain
// values to the output iterator, guarding the last (possibly partial) tile.
template<
    unsigned int block_size,
    unsigned int items_per_thread,
    typename result_type,
    ::rocprim::block_store_method store_method
>
struct unwrap_store
{
    using block_store_values
        = ::rocprim::block_store<result_type, block_size, items_per_thread, store_method>;
    using storage_type = typename block_store_values::storage_type;

    template<typename OutputIterator>
    ROCPRIM_DEVICE
    void store(OutputIterator     output,
               const unsigned int flat_block_id,
               const size_t       starting_block,
               const size_t       number_of_blocks,
               const unsigned int flat_thread_id,
               const size_t       size,
               const rocprim::tuple<result_type, bool> (&wrapped_values)[items_per_thread],
               storage_type& storage)
    {
        constexpr static unsigned int items_per_block = items_per_thread * block_size;

        const unsigned int block_offset = flat_block_id * items_per_block;
        OutputIterator     block_output = output + block_offset;

        result_type thread_values[items_per_thread];
        if(starting_block + flat_block_id < number_of_blocks - 1)
        {
            // Full tile: unwrap every item.
            ROCPRIM_UNROLL
            for(unsigned int i = 0; i < items_per_thread; ++i)
            {
                thread_values[i] = rocprim::get<0>(wrapped_values[i]);
            }
            // Reusing shared memory from scan to perform store
            rocprim::syncthreads();
            block_store_values{}.store(block_output, thread_values, storage);
        }
        else
        {
            // Last tile: only unwrap and store the valid items.
            const unsigned int valid_in_last_block
                = static_cast<unsigned int>(size - items_per_block * (number_of_blocks - 1));
            ROCPRIM_UNROLL
            for(unsigned int i = 0; i < items_per_thread; ++i)
            {
                if(flat_thread_id * items_per_thread + i >= valid_in_last_block)
                {
                    break;
                }
                thread_values[i] = rocprim::get<0>(wrapped_values[i]);
            }
            // Reusing shared memory from scan to perform store
            rocprim::syncthreads();
            block_store_values{}.store(block_output, thread_values, valid_in_last_block, storage);
        }
    }
};
// Device-side implementation of scan-by-key using a single-pass decoupled
// look-back scan over (value, head-flag) pairs.
//
// previous_last_value passes the carry from a previous grid in a multi-grid
// launch (nullptr for a single launch / the first grid).
template<
    bool Exclusive,
    typename Config,
    typename KeyInputIterator,
    typename InputIterator,
    typename OutputIterator,
    typename ResultType,
    typename CompareFunction,
    typename BinaryFunction,
    typename LookbackScanState
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void device_scan_by_key_kernel_impl(KeyInputIterator      keys,
                                    InputIterator         values,
                                    OutputIterator        output,
                                    ResultType            initial_value,
                                    const CompareFunction compare,
                                    const BinaryFunction  scan_op,
                                    LookbackScanState     scan_state,
                                    const size_t          size,
                                    const size_t          starting_block,
                                    const size_t          number_of_blocks,
                                    ordered_block_id<unsigned int> ordered_bid,
                                    const rocprim::tuple<ResultType, bool>* const previous_last_value)
{
    using result_type = ResultType;
    static_assert(
        std::is_same<rocprim::tuple<ResultType, bool>,
                     typename LookbackScanState::value_type>::value,
        "value_type of LookbackScanState must be tuple of result type and flag");

    constexpr auto block_size         = Config::block_size;
    constexpr auto items_per_thread   = Config::items_per_thread;
    constexpr auto load_keys_method   = Config::block_load_method;
    constexpr auto load_values_method = load_keys_method;

    using key_type = typename std::iterator_traits<KeyInputIterator>::value_type;
    using load_flagged = load_values_flagged<Exclusive,
                                             block_size,
                                             items_per_thread,
                                             key_type,
                                             result_type,
                                             load_keys_method,
                                             load_values_method>;

    // Wrap the user operator so it restarts at head flags.
    auto wrapped_op = headflag_scan_op_wrapper<result_type, bool, BinaryFunction>{scan_op};
    using wrapped_type = rocprim::tuple<result_type, bool>;
    using block_scan_type
        = ::rocprim::block_scan<wrapped_type, block_size, Config::block_scan_method>;

    constexpr auto store_method = Config::block_store_method;
    using store_unwrap  = unwrap_store<block_size, items_per_thread, result_type, store_method>;
    using order_bid_type = ordered_block_id<unsigned int>;

    // load+ordered_bid, scan, and store phases each reuse the same shared memory.
    ROCPRIM_SHARED_MEMORY union
    {
        struct
        {
            typename load_flagged::storage_type   load;
            typename order_bid_type::storage_type ordered_bid;
        };
        typename block_scan_type::storage_type scan;
        typename store_unwrap::storage_type    store;
    } storage;

    const auto flat_thread_id = ::rocprim::detail::block_thread_id<0>();
    const auto flat_block_id  = ordered_bid.get(flat_thread_id, storage.ordered_bid);

    // Load input
    wrapped_type wrapped_values[items_per_thread];
    load_flagged{}.load(keys,
                        values,
                        compare,
                        initial_value,
                        flat_block_id,
                        starting_block,
                        number_of_blocks,
                        flat_thread_id,
                        size,
                        wrapped_values,
                        storage.load);

    // Reusing the storage from load to perform the scan
    ::rocprim::syncthreads();

    // Perform look-back scan
    if(flat_block_id == 0)
    {
        auto wrapped_initial_value = rocprim::make_tuple(initial_value, false);
        // previous_last_value is used to pass the value from the previous grid, if this is a
        // multi grid launch
        if(previous_last_value != nullptr)
        {
            if(Exclusive)
            {
                rocprim::get<0>(wrapped_initial_value) = rocprim::get<0>(*previous_last_value);
            }
            else if(flat_thread_id == 0)
            {
                wrapped_values[0] = wrapped_op(*previous_last_value, wrapped_values[0]);
            }
        }
        wrapped_type reduction;
        lookback_block_scan<Exclusive, block_scan_type>(wrapped_values,
                                                        wrapped_initial_value,
                                                        reduction,
                                                        storage.scan,
                                                        wrapped_op);
        if(flat_thread_id == 0)
        {
            scan_state.set_complete(flat_block_id, reduction);
        }
    }
    else
    {
        auto prefix_op
            = lookback_scan_prefix_op<wrapped_type, decltype(wrapped_op), decltype(scan_state)>{
                flat_block_id, wrapped_op, scan_state};
        // Scan of block values
        lookback_block_scan<Exclusive, block_scan_type>(wrapped_values,
                                                        storage.scan,
                                                        prefix_op,
                                                        wrapped_op);
    }

    // Store output
    // synchronization is inside the function after unwrapping
    store_unwrap{}.store(output,
                         flat_block_id,
                         starting_block,
                         number_of_blocks,
                         flat_thread_id,
                         size,
                         wrapped_values,
                         storage.store);
}
}
// namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_BY_KEY_HPP_
\ No newline at end of file
3rdparty/cub/rocprim/device/detail/device_scan_common.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_SCAN_COMMON_HPP_
#define ROCPRIM_DEVICE_SCAN_COMMON_HPP_
#include "../../config.hpp"
#include "../../intrinsics/thread.hpp"
#include "lookback_scan_state.hpp"
#include "ordered_block_id.hpp"
#include <cuda_runtime.h>
BEGIN_ROCPRIM_NAMESPACE
namespace
detail
{
// Initializes the decoupled look-back scan state before a scan launch:
// resets the ordered block id, optionally saves the previous launch's last
// prefix into save_dest, and marks all block prefixes as invalid.
template<typename LookBackScanState>
ROCPRIM_KERNEL __launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE)
void init_lookback_scan_state_kernel(LookBackScanState  lookback_scan_state,
                                     const unsigned int number_of_blocks,
                                     ordered_block_id<unsigned int> ordered_bid,
                                     unsigned int save_index = 0,
                                     typename LookBackScanState::value_type* const save_dest = nullptr)
{
    const unsigned int block_id        = ::rocprim::detail::block_id<0>();
    const unsigned int block_size      = ::rocprim::detail::block_size<0>();
    const unsigned int block_thread_id = ::rocprim::detail::block_thread_id<0>();
    const unsigned int id              = (block_id * block_size) + block_thread_id;

    // Reset ordered_block_id
    if(id == 0)
    {
        ordered_bid.reset();
    }

    // Save the reduction (i.e. the last prefix) from the previous user of lookback_scan_state.
    // If the thread that should reset it is participating then it saves the value before
    // reseting, otherwise the first thread saves it (it won't be reset by any thread).
    if(save_dest != nullptr
       && ((number_of_blocks <= save_index && id == 0) || id == save_index))
    {
        typename LookBackScanState::value_type value;
        typename LookBackScanState::flag_type  dummy_flag;
        lookback_scan_state.get(save_index, dummy_flag, value);
        *save_dest = value;
    }

    // Initialize lookback scan status
    lookback_scan_state.initialize_prefix(id, number_of_blocks);
}
// Inclusive variant (selected when Exclusive == false) of the block-level
// scan used by the look-back kernels; the initial value is ignored because
// an inclusive scan has no seed.
template<
    bool Exclusive,
    class BlockScan,
    class T,
    unsigned int ItemsPerThread,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto lookback_block_scan(T (&values)[ItemsPerThread],
                         T /* initial_value */,
                         T& reduction,
                         typename BlockScan::storage_type& storage,
                         BinaryFunction scan_op)
    -> typename std::enable_if<!Exclusive>::type
{
    BlockScan().inclusive_scan(values,    // input
                               values,    // output
                               reduction,
                               storage,
                               scan_op);
}
// Exclusive variant (selected when Exclusive == true) of the block-level scan
// used by the look-back kernels. The reported reduction must include the
// initial value, so it is folded in after the exclusive scan.
template<
    bool Exclusive,
    class BlockScan,
    class T,
    unsigned int ItemsPerThread,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto lookback_block_scan(T (&values)[ItemsPerThread],
                         T initial_value,
                         T& reduction,
                         typename BlockScan::storage_type& storage,
                         BinaryFunction scan_op)
    -> typename std::enable_if<Exclusive>::type
{
    BlockScan().exclusive_scan(values,    // input
                               values,    // output
                               initial_value,
                               reduction,
                               storage,
                               scan_op);
    reduction = scan_op(initial_value, reduction);
}
// Inclusive variant (Exclusive == false) taking a prefix-callback instead of
// an initial value — used by every block except the first, which obtains its
// running prefix from the look-back state through the callback.
template<
    bool Exclusive,
    class BlockScan,
    class T,
    unsigned int ItemsPerThread,
    class PrefixCallback,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto lookback_block_scan(T (&values)[ItemsPerThread],
                         typename BlockScan::storage_type& storage,
                         PrefixCallback& prefix_callback_op,
                         BinaryFunction scan_op)
    -> typename std::enable_if<!Exclusive>::type
{
    BlockScan().inclusive_scan(values,    // input
                               values,    // output
                               storage,
                               prefix_callback_op,
                               scan_op);
}
// Exclusive variant (Exclusive == true) taking a prefix-callback instead of
// an initial value — the running prefix comes from the look-back state.
template<
    bool Exclusive,
    class BlockScan,
    class T,
    unsigned int ItemsPerThread,
    class PrefixCallback,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto lookback_block_scan(T (&values)[ItemsPerThread],
                         typename BlockScan::storage_type& storage,
                         PrefixCallback& prefix_callback_op,
                         BinaryFunction scan_op)
    -> typename std::enable_if<Exclusive>::type
{
    BlockScan().exclusive_scan(values,    // input
                               values,    // output
                               storage,
                               prefix_callback_op,
                               scan_op);
}
}
// namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_SCAN_COMMON_HPP_
\ No newline at end of file
3rdparty/cub/rocprim/device/detail/device_scan_lookback.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_LOOKBACK_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_LOOKBACK_HPP_
#include <type_traits>
#include <iterator>
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../types.hpp"
#include "../../block/block_load.hpp"
#include "../../block/block_store.hpp"
#include "../../block/block_scan.hpp"
#include "device_scan_common.hpp"
#include "lookback_scan_state.hpp"
#include "ordered_block_id.hpp"
BEGIN_ROCPRIM_NAMESPACE
// Single pass prefix scan was implemented based on:
// Merrill, D. and Garland, M. Single-pass Parallel Prefix Scan with Decoupled Look-back.
// Technical Report NVR2016-001, NVIDIA Research. Mar. 2016.
namespace
detail
{
// Single-pass decoupled look-back scan kernel body.
//
// previous_last_element / new_last_element / override_first_value /
// save_last_value support chunked multi-launch scans: when a previous chunk
// was processed, its last element seeds this chunk, and the last element of
// this chunk can be saved for the next one.
template<
    bool Exclusive,
    class Config,
    class InputIterator,
    class OutputIterator,
    class BinaryFunction,
    class ResultType,
    class LookbackScanState
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void lookback_scan_kernel_impl(InputIterator      input,
                               OutputIterator     output,
                               const size_t       size,
                               ResultType         initial_value,
                               BinaryFunction     scan_op,
                               LookbackScanState  scan_state,
                               const unsigned int number_of_blocks,
                               ordered_block_id<unsigned int> ordered_bid,
                               ResultType* previous_last_element = nullptr,
                               ResultType* new_last_element      = nullptr,
                               bool        override_first_value  = false,
                               bool        save_last_value       = false)
{
    using result_type = ResultType;
    static_assert(std::is_same<result_type, typename LookbackScanState::value_type>::value,
                  "value_type of LookbackScanState must be result_type");

    constexpr auto         block_size       = Config::block_size;
    constexpr auto         items_per_thread = Config::items_per_thread;
    constexpr unsigned int items_per_block  = block_size * items_per_thread;

    using block_load_type = ::rocprim::block_load<result_type, block_size, items_per_thread,
                                                  Config::block_load_method>;
    using block_store_type = ::rocprim::block_store<result_type, block_size, items_per_thread,
                                                    Config::block_store_method>;
    using block_scan_type = ::rocprim::block_scan<result_type, block_size,
                                                  Config::block_scan_method>;
    using order_bid_type = ordered_block_id<unsigned int>;
    using lookback_scan_prefix_op_type
        = lookback_scan_prefix_op<result_type, BinaryFunction, LookbackScanState>;

    ROCPRIM_SHARED_MEMORY struct
    {
        typename order_bid_type::storage_type ordered_bid;
        union
        {
            typename block_load_type::storage_type  load;
            typename block_store_type::storage_type store;
            typename block_scan_type::storage_type  scan;
        };
    } storage;

    const auto flat_block_thread_id = ::rocprim::detail::block_thread_id<0>();
    const auto flat_block_id = ordered_bid.get(flat_block_thread_id, storage.ordered_bid);
    const unsigned int block_offset = flat_block_id * items_per_block;
    const auto valid_in_last_block = size - items_per_block * (number_of_blocks - 1);

    // For input values
    result_type values[items_per_thread];

    // load input values into values
    if(flat_block_id == (number_of_blocks - 1)) // last block
    {
        block_load_type().load(input + block_offset,
                               values,
                               valid_in_last_block,
                               *(input + block_offset),
                               storage.load);
    }
    else
    {
        block_load_type().load(input + block_offset, values, storage.load);
    }
    ::rocprim::syncthreads(); // sync threads to reuse shared memory

    if(flat_block_id == 0)
    {
        // override_first_value only true when the first chunk already processed
        // and input iterator starts from an offset.
        if(override_first_value)
        {
            if(Exclusive)
                initial_value = scan_op(previous_last_element[0],
                                        static_cast<result_type>(*(input - 1)));
            else if(flat_block_thread_id == 0)
                values[0] = scan_op(previous_last_element[0], values[0]);
        }

        result_type reduction;
        lookback_block_scan<Exclusive, block_scan_type>(values, // input/output
                                                        initial_value,
                                                        reduction,
                                                        storage.scan,
                                                        scan_op);
        if(flat_block_thread_id == 0)
        {
            scan_state.set_complete(flat_block_id, reduction);
        }
    }
    else
    {
        // Scan of block values
        auto prefix_op = lookback_scan_prefix_op_type(flat_block_id, scan_op, scan_state);
        lookback_block_scan<Exclusive, block_scan_type>(values, // input/output
                                                        storage.scan,
                                                        prefix_op,
                                                        scan_op);
    }
    ::rocprim::syncthreads(); // sync threads to reuse shared memory

    // Save values into output array
    if(flat_block_id == (number_of_blocks - 1)) // last block
    {
        block_store_type().store(output + block_offset,
                                 values,
                                 valid_in_last_block,
                                 storage.store);

        // The thread owning the last valid item saves it for the next chunk.
        if(save_last_value
           && (::rocprim::detail::block_thread_id<0>()
               == (valid_in_last_block - 1) / items_per_thread))
        {
            for(unsigned int i = 0; i < items_per_thread; i++)
            {
                if(i == (valid_in_last_block - 1) % items_per_thread)
                {
                    new_last_element[0] = values[i];
                }
            }
        }
    }
    else
    {
        block_store_type().store(output + block_offset, values, storage.store);
    }
}
}
// end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_LOOKBACK_HPP_
3rdparty/cub/rocprim/device/detail/device_scan_reduce_then_scan.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_REDUCE_THEN_SCAN_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_REDUCE_THEN_SCAN_HPP_
#include <type_traits>
#include <iterator>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../types.hpp"
#include "../../block/block_load.hpp"
#include "../../block/block_store.hpp"
#include "../../block/block_scan.hpp"
#include "../../block/block_reduce.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace
detail
{
// Helper functions for performing exclusive or inclusive
// block scan in single_scan.
// Exclusive helper (selected when Exclusive == true) used by single_scan to
// dispatch the block scan with a seed value.
template<
    bool Exclusive,
    class BlockScan,
    class T,
    unsigned int ItemsPerThread,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto single_scan_block_scan(T (&input)[ItemsPerThread],
                            T (&output)[ItemsPerThread],
                            T initial_value,
                            typename BlockScan::storage_type& storage,
                            BinaryFunction scan_op)
    -> typename std::enable_if<Exclusive>::type
{
    BlockScan().exclusive_scan(input,  // input
                               output, // output
                               initial_value,
                               storage,
                               scan_op);
}
// Inclusive helper (selected when Exclusive == false); the initial value is
// unused since inclusive scans have no seed.
template<
    bool Exclusive,
    class BlockScan,
    class T,
    unsigned int ItemsPerThread,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto single_scan_block_scan(T (&input)[ItemsPerThread],
                            T (&output)[ItemsPerThread],
                            T initial_value,
                            typename BlockScan::storage_type& storage,
                            BinaryFunction scan_op)
    -> typename std::enable_if<!Exclusive>::type
{
    (void)initial_value;
    BlockScan().inclusive_scan(input,  // input
                               output, // output
                               storage,
                               scan_op);
}
// Scan kernel body for inputs that fit in a single block: load, one block
// scan, store — with guarded accesses for the (possibly partial) input.
template<
    bool Exclusive,
    class Config,
    class InputIterator,
    class OutputIterator,
    class BinaryFunction,
    class ResultType
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void single_scan_kernel_impl(InputIterator  input,
                             const size_t   input_size,
                             ResultType     initial_value,
                             OutputIterator output,
                             BinaryFunction scan_op)
{
    constexpr unsigned int block_size       = Config::block_size;
    constexpr unsigned int items_per_thread = Config::items_per_thread;

    using result_type = ResultType;
    using block_load_type = ::rocprim::block_load<result_type, block_size, items_per_thread,
                                                  Config::block_load_method>;
    using block_store_type = ::rocprim::block_store<result_type, block_size, items_per_thread,
                                                    Config::block_store_method>;
    using block_scan_type = ::rocprim::block_scan<result_type, block_size,
                                                  Config::block_scan_method>;

    // Load, scan, and store phases reuse the same shared memory.
    ROCPRIM_SHARED_MEMORY union
    {
        typename block_load_type::storage_type  load;
        typename block_store_type::storage_type store;
        typename block_scan_type::storage_type  scan;
    } storage;

    result_type values[items_per_thread];

    // load input values into values
    block_load_type().load(input,
                           values,
                           input_size,
                           *(input),
                           storage.load);
    ::rocprim::syncthreads(); // sync threads to reuse shared memory

    single_scan_block_scan<Exclusive, block_scan_type>(values, // input
                                                       values, // output
                                                       initial_value,
                                                       storage.scan,
                                                       scan_op);
    ::rocprim::syncthreads(); // sync threads to reuse shared memory

    // Save values into output array
    block_store_type().store(output, values, input_size, storage.store);
}
// Calculates block prefixes that will be used in final_scan
// when performing block scan operations.
template
<
class
Config
,
class
InputIterator
,
class
BinaryFunction
,
class
ResultType
>
ROCPRIM_DEVICE
ROCPRIM_FORCE_INLINE
void
block_reduce_kernel_impl
(
InputIterator
input
,
BinaryFunction
scan_op
,
ResultType
*
block_prefixes
)
{
constexpr
unsigned
int
block_size
=
Config
::
block_size
;
constexpr
unsigned
int
items_per_thread
=
Config
::
items_per_thread
;
using
result_type
=
ResultType
;
using
block_reduce_type
=
::
rocprim
::
block_reduce
<
result_type
,
block_size
,
::
rocprim
::
block_reduce_algorithm
::
using_warp_reduce
>
;
using
block_load_type
=
::
rocprim
::
block_load
<
result_type
,
block_size
,
items_per_thread
,
Config
::
block_load_method
>
;
ROCPRIM_SHARED_MEMORY
union
{
typename
block_load_type
::
storage_type
load
;
typename
block_reduce_type
::
storage_type
reduce
;
}
storage
;
const
unsigned
int
flat_id
=
::
rocprim
::
detail
::
block_thread_id
<
0
>
();
const
unsigned
int
flat_block_id
=
::
rocprim
::
detail
::
block_id
<
0
>
();
const
unsigned
int
block_offset
=
flat_block_id
*
items_per_thread
*
block_size
;
// For input values
result_type
values
[
items_per_thread
];
result_type
block_prefix
;
block_load_type
()
.
load
(
input
+
block_offset
,
values
,
storage
.
load
);
::
rocprim
::
syncthreads
();
// sync threads to reuse shared memory
block_reduce_type
()
.
reduce
(
values
,
// input
block_prefix
,
// output
storage
.
reduce
,
scan_op
);
// Save block prefix
if
(
flat_id
==
0
)
{
block_prefixes
[
flat_block_id
]
=
block_prefix
;
}
}
// Helper for performing the exclusive block scan in final_scan.
// Enabled only when Exclusive is true (SFINAE on the return type).
template<
    bool Exclusive,
    class BlockScan,
    class T,
    unsigned int ItemsPerThread,
    class ResultType,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto final_scan_block_scan(const unsigned int flat_block_id,
                           T (&input)[ItemsPerThread],
                           T (&output)[ItemsPerThread],
                           T initial_value,
                           ResultType* block_prefixes,
                           typename BlockScan::storage_type& storage,
                           BinaryFunction scan_op)
    -> typename std::enable_if<Exclusive>::type
{
    // Every block except the first folds the running total of all
    // preceding blocks into its seed value.
    T seed = initial_value;
    if(flat_block_id > 0)
    {
        seed = scan_op(seed, block_prefixes[flat_block_id - 1]);
    }
    BlockScan().exclusive_scan(
        input,  // input
        output, // output
        seed,
        storage,
        scan_op);
}
// Helper for performing the inclusive block scan in final_scan.
// Enabled only when Exclusive is false (SFINAE on the return type).
template<
    bool Exclusive,
    class BlockScan,
    class T,
    unsigned int ItemsPerThread,
    class ResultType,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto final_scan_block_scan(const unsigned int flat_block_id,
                           T (&input)[ItemsPerThread],
                           T (&output)[ItemsPerThread],
                           T initial_value,
                           ResultType* block_prefixes,
                           typename BlockScan::storage_type& storage,
                           BinaryFunction scan_op)
    -> typename std::enable_if<!Exclusive>::type
{
    // Inclusive scans do not use an initial value.
    (void)initial_value;

    if(flat_block_id != 0)
    {
        // Later blocks seed the scan with the combined total of all
        // preceding blocks, delivered through a prefix callback.
        const auto fetch_prefix = [block_prefixes, flat_block_id](const T& /*not used*/)
        {
            return block_prefixes[flat_block_id - 1];
        };
        BlockScan().inclusive_scan(
            input,  // input
            output, // output
            storage,
            fetch_prefix,
            scan_op);
        return;
    }

    // The first block has no predecessor, so it scans directly.
    BlockScan().inclusive_scan(
        input,  // input
        output, // output
        storage,
        scan_op);
}
// Final pass of the reduce-then-scan algorithm: re-reads the input, seeds
// each block with the already-scanned prefix of the preceding blocks
// (block_prefixes) and writes the completed scan to `output`.
// - previous_last_element / override_first_value: used when this call
//   continues a scan whose earlier chunk was already processed and `input`
//   starts at an offset; the carried-over last element is folded in.
// - new_last_element / save_last_value: if set, the last scanned element of
//   this chunk is written out for the next chunk to continue from.
template<
    bool Exclusive,
    class Config,
    class InputIterator,
    class OutputIterator,
    class BinaryFunction,
    class ResultType
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void final_scan_kernel_impl(InputIterator input,
                            const size_t input_size,
                            OutputIterator output,
                            ResultType initial_value,
                            BinaryFunction scan_op,
                            ResultType* block_prefixes,
                            ResultType* previous_last_element = nullptr,
                            ResultType* new_last_element = nullptr,
                            bool override_first_value = false,
                            bool save_last_value = false)
{
    constexpr unsigned int block_size = Config::block_size;
    constexpr unsigned int items_per_thread = Config::items_per_thread;

    using result_type = ResultType;

    using block_load_type = ::rocprim::block_load<
        result_type, block_size, items_per_thread, Config::block_load_method>;
    using block_store_type = ::rocprim::block_store<
        result_type, block_size, items_per_thread, Config::block_store_method>;
    using block_scan_type = ::rocprim::block_scan<
        result_type, block_size, Config::block_scan_method>;

    // Load, scan and store phases are barrier-separated, so their
    // shared-memory storages may alias in a union.
    ROCPRIM_SHARED_MEMORY union
    {
        typename block_load_type::storage_type load;
        typename block_store_type::storage_type store;
        typename block_scan_type::storage_type scan;
    } storage;

    // It's assumed kernel is executed in 1D
    const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
    constexpr unsigned int items_per_block = block_size * items_per_thread;
    const unsigned int block_offset = flat_block_id * items_per_block;
    // TODO: number_of_blocks can be calculated on host
    const unsigned int number_of_blocks = (input_size + items_per_block - 1) / items_per_block;

    // For input values
    result_type values[items_per_thread];

    // TODO: valid_in_last_block can be calculated on host
    auto valid_in_last_block = input_size - items_per_block * (number_of_blocks - 1);
    // load input values into values
    if(flat_block_id == (number_of_blocks - 1)) // last block
    {
        // Partial tile: pad out-of-range slots with the first in-range item.
        block_load_type().load(
            input + block_offset,
            values,
            valid_in_last_block,
            *(input + block_offset),
            storage.load);
    }
    else
    {
        block_load_type().load(
            input + block_offset,
            values,
            storage.load);
    }
    ::rocprim::syncthreads(); // sync threads to reuse shared memory

    // override_first_value only true when the first chunk already processed
    // and input iterator starts from an offset.
    if(override_first_value && flat_block_id == 0)
    {
        if(Exclusive)
            // Fold the carried element with the item just before this chunk.
            initial_value = scan_op(previous_last_element[0], *(input - 1));
        else if(::rocprim::detail::block_thread_id<0>() == 0)
            // Inclusive: fold the carried element directly into the first item.
            values[0] = scan_op(previous_last_element[0], values[0]);
    }

    final_scan_block_scan<Exclusive, block_scan_type>(
        flat_block_id,
        values, // input
        values, // output
        initial_value,
        block_prefixes,
        storage.scan,
        scan_op);
    ::rocprim::syncthreads(); // sync threads to reuse shared memory

    // Save values into output array
    if(flat_block_id == (number_of_blocks - 1)) // last block
    {
        block_store_type().store(
            output + block_offset,
            values,
            valid_in_last_block,
            storage.store);
        // The thread owning the last valid item exports it for the next chunk.
        if(save_last_value &&
           (::rocprim::detail::block_thread_id<0>() ==
            (valid_in_last_block - 1) / items_per_thread))
        {
            for(unsigned int i = 0; i < items_per_thread; i++)
            {
                if(i == (valid_in_last_block - 1) % items_per_thread)
                {
                    new_last_element[0] = values[i];
                }
            }
        }
    }
    else
    {
        block_store_type().store(
            output + block_offset,
            values,
            storage.store);
    }
}
// Returns size of temporary storage in bytes.
//
// The reduce-then-scan algorithm needs one T-sized prefix per block, per
// level of reduction, until a level fits inside a single block. This adds up
// the prefix arrays for every level.
template<class T>
size_t scan_get_temporary_storage_bytes(size_t input_size, size_t items_per_block)
{
    size_t bytes = 0;
    size_t remaining = input_size;
    // Each level shrinks the problem to one prefix per block.
    while(remaining > items_per_block)
    {
        remaining = (remaining + items_per_block - 1) / items_per_block;
        bytes += remaining * sizeof(T);
    }
    return bytes;
}
}
// end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_REDUCE_THEN_SCAN_HPP_
3rdparty/cub/rocprim/device/detail/device_segmented_radix_sort.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_RADIX_SORT_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_RADIX_SORT_HPP_
#include <type_traits>
#include <iterator>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../types.hpp"
#include "../../block/block_load.hpp"
#include "../../block/block_store.hpp"
#include "../../block/block_scan.hpp"
#include "../../warp/warp_load.hpp"
#include "../../warp/warp_sort.hpp"
#include "../../warp/warp_store.hpp"
#include "../device_segmented_radix_sort_config.hpp"
#include "device_radix_sort.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace
detail
{
// Performs one radix pass (RadixBits wide, starting at `bit`) over a single
// segment, alternating ("ping-ponging") between the temporary buffer and the
// final output buffer across successive passes.
template<
    class Key,
    class Value,
    unsigned int WarpSize,
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    unsigned int RadixBits,
    bool Descending
>
class segmented_radix_sort_helper
{
    // Number of distinct digit values in one RadixBits-wide pass.
    static constexpr unsigned int radix_size = 1 << RadixBits;

    using key_type = Key;
    using value_type = Value;

    // Counts occurrences of each digit in the segment.
    using count_helper_type = radix_digit_count_helper<WarpSize, BlockSize, ItemsPerThread, RadixBits, Descending>;
    // Scans the digit histogram to get each digit's starting position.
    using scan_type = typename ::rocprim::block_scan<unsigned int, radix_size>;
    // Ranks keys by the current digit and scatters them to their destination.
    using sort_and_scatter_helper = radix_sort_and_scatter_helper<
        BlockSize, ItemsPerThread, RadixBits, Descending,
        key_type, value_type, unsigned int
    >;

public:
    // The count and sort-and-scatter phases are separated by syncthreads,
    // so their shared-memory storages may alias in a union.
    union storage_type
    {
        typename segmented_radix_sort_helper<Key, Value, WarpSize, BlockSize, ItemsPerThread, RadixBits, Descending>::count_helper_type::storage_type count_helper;
        typename segmented_radix_sort_helper<Key, Value, WarpSize, BlockSize, ItemsPerThread, RadixBits, Descending>::sort_and_scatter_helper::storage_type sort_and_scatter_helper;
    };

    // Generic-iterator overload: picks the source/destination pair for this
    // pass from (input, tmp, output) based on whether this is the first pass
    // and which buffer the pass must end in (`to_output`).
    template<
        class KeysInputIterator,
        class KeysOutputIterator,
        class ValuesInputIterator,
        class ValuesOutputIterator
    >
    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
    void sort(KeysInputIterator keys_input,
              key_type* keys_tmp,
              KeysOutputIterator keys_output,
              ValuesInputIterator values_input,
              value_type* values_tmp,
              ValuesOutputIterator values_output,
              bool to_output,
              unsigned int begin_offset,
              unsigned int end_offset,
              unsigned int bit,
              unsigned int begin_bit,
              unsigned int end_bit,
              storage_type& storage)
    {
        // Handle cases when (end_bit - bit) is not divisible by radix_bits, i.e. the last
        // iteration has a shorter mask.
        const unsigned int current_radix_bits = ::rocprim::min(RadixBits, end_bit - bit);
        const bool is_first_iteration = (bit == begin_bit);
        if(is_first_iteration)
        {
            // First pass always reads from the user's input iterators.
            if(to_output)
            {
                sort(
                    keys_input, keys_output,
                    values_input, values_output,
                    begin_offset, end_offset,
                    bit, current_radix_bits,
                    storage);
            }
            else
            {
                sort(
                    keys_input, keys_tmp,
                    values_input, values_tmp,
                    begin_offset, end_offset,
                    bit, current_radix_bits,
                    storage);
            }
        }
        else
        {
            // Later passes ping-pong between tmp and output buffers.
            if(to_output)
            {
                sort(
                    keys_tmp, keys_output,
                    values_tmp, values_output,
                    begin_offset, end_offset,
                    bit, current_radix_bits,
                    storage);
            }
            else
            {
                sort(
                    keys_output, keys_tmp,
                    values_output, values_tmp,
                    begin_offset, end_offset,
                    bit, current_radix_bits,
                    storage);
            }
        }
    }

    // When all iterators are raw pointers, this overload is used to minimize code duplication in the kernel
    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
    void sort(key_type* keys_input,
              key_type* keys_tmp,
              key_type* keys_output,
              value_type* values_input,
              value_type* values_tmp,
              value_type* values_output,
              bool to_output,
              unsigned int begin_offset,
              unsigned int end_offset,
              unsigned int bit,
              unsigned int begin_bit,
              unsigned int end_bit,
              storage_type& storage)
    {
        // Handle cases when (end_bit - bit) is not divisible by radix_bits, i.e. the last
        // iteration has a shorter mask.
        const unsigned int current_radix_bits = ::rocprim::min(RadixBits, end_bit - bit);
        const bool is_first_iteration = (bit == begin_bit);
        // Same buffer selection as the generic overload, but resolved into
        // plain pointers so only one sort instantiation is emitted.
        key_type* current_keys_input;
        key_type* current_keys_output;
        value_type* current_values_input;
        value_type* current_values_output;
        if(is_first_iteration)
        {
            if(to_output)
            {
                current_keys_input = keys_input;
                current_keys_output = keys_output;
                current_values_input = values_input;
                current_values_output = values_output;
            }
            else
            {
                current_keys_input = keys_input;
                current_keys_output = keys_tmp;
                current_values_input = values_input;
                current_values_output = values_tmp;
            }
        }
        else
        {
            if(to_output)
            {
                current_keys_input = keys_tmp;
                current_keys_output = keys_output;
                current_values_input = values_tmp;
                current_values_output = values_output;
            }
            else
            {
                current_keys_input = keys_output;
                current_keys_output = keys_tmp;
                current_values_input = values_output;
                current_values_output = values_tmp;
            }
        }
        sort(
            current_keys_input, current_keys_output,
            current_values_input, current_values_output,
            begin_offset, end_offset,
            bit, current_radix_bits,
            storage);
    }

private:
    // One counting-sort pass: histogram digits, exclusive-scan the histogram
    // to get digit start positions, then rank-and-scatter keys/values.
    template<
        class KeysInputIterator,
        class KeysOutputIterator,
        class ValuesInputIterator,
        class ValuesOutputIterator
    >
    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
    void sort(KeysInputIterator keys_input,
              KeysOutputIterator keys_output,
              ValuesInputIterator values_input,
              ValuesOutputIterator values_output,
              unsigned int begin_offset,
              unsigned int end_offset,
              unsigned int bit,
              unsigned int current_radix_bits,
              storage_type& storage)
    {
        unsigned int digit_count;
        count_helper_type().count_digits(
            keys_input,
            begin_offset, end_offset,
            bit, current_radix_bits,
            storage.count_helper,
            digit_count);
        unsigned int digit_start;
        // Starting position of each digit's bucket within the segment.
        scan_type().exclusive_scan(digit_count, digit_start, 0);
        digit_start += begin_offset;
        ::rocprim::syncthreads();
        sort_and_scatter_helper().sort_and_scatter(
            keys_input, keys_output,
            values_input, values_output,
            begin_offset, end_offset,
            bit, current_radix_bits,
            digit_start,
            storage.sort_and_scatter_helper);
        ::rocprim::syncthreads();
    }
};
// Sorts a whole segment (up to BlockSize * ItemsPerThread items) with one
// block-wide radix sort instead of multiple passes. Recursively tries
// helpers with half the items per thread first, so short segments use the
// smallest sort that fits.
template<
    class Key,
    class Value,
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    bool Descending
>
class segmented_radix_sort_single_block_helper
{
    using key_type = Key;
    using value_type = Value;

    // Encodes keys into orderable bit patterns (handles sign/descending).
    using key_codec = radix_key_codec<key_type, Descending>;
    using bit_key_type = typename key_codec::bit_key_type;

    using keys_load_type = ::rocprim::block_load<
        key_type, BlockSize, ItemsPerThread,
        ::rocprim::block_load_method::block_load_transpose>;
    using values_load_type = ::rocprim::block_load<
        value_type, BlockSize, ItemsPerThread,
        ::rocprim::block_load_method::block_load_transpose>;
    using sort_type = ::rocprim::block_radix_sort<
        key_type, BlockSize, ItemsPerThread, value_type>;
    using keys_store_type = ::rocprim::block_store<
        key_type, BlockSize, ItemsPerThread,
        ::rocprim::block_store_method::block_store_transpose>;
    using values_store_type = ::rocprim::block_store<
        value_type, BlockSize, ItemsPerThread,
        ::rocprim::block_store_method::block_store_transpose>;

    // Keys-only sorts pass ::rocprim::empty_type as the value type.
    static constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;

public:
    // All phases are barrier-separated, so their storages may alias.
    union storage_type
    {
        typename keys_load_type::storage_type keys_load;
        typename values_load_type::storage_type values_load;
        typename sort_type::storage_type sort;
        typename keys_store_type::storage_type keys_store;
        typename values_store_type::storage_type values_store;
    };

    // Selects the destination (output or tmp) based on `to_output` and sorts.
    template<
        class KeysInputIterator,
        class KeysOutputIterator,
        class ValuesInputIterator,
        class ValuesOutputIterator
    >
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void sort(KeysInputIterator keys_input,
              key_type* keys_tmp,
              KeysOutputIterator keys_output,
              ValuesInputIterator values_input,
              value_type* values_tmp,
              ValuesOutputIterator values_output,
              bool to_output,
              unsigned int begin_offset,
              unsigned int end_offset,
              unsigned int begin_bit,
              unsigned int end_bit,
              storage_type& storage)
    {
        if(to_output)
        {
            sort(
                keys_input, keys_output,
                values_input, values_output,
                begin_offset, end_offset,
                begin_bit, end_bit,
                storage);
        }
        else
        {
            sort(
                keys_input, keys_tmp,
                values_input, values_tmp,
                begin_offset, end_offset,
                begin_bit, end_bit,
                storage);
        }
    }

    // When all iterators are raw pointers, this overload is used to minimize code duplication in the kernel
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void sort(key_type* keys_input,
              key_type* keys_tmp,
              key_type* keys_output,
              value_type* values_input,
              value_type* values_tmp,
              value_type* values_output,
              bool to_output,
              unsigned int begin_offset,
              unsigned int end_offset,
              unsigned int begin_bit,
              unsigned int end_bit,
              storage_type& storage)
    {
        sort(
            keys_input,
            (to_output ? keys_output : keys_tmp),
            values_input,
            (to_output ? values_output : values_tmp),
            begin_offset, end_offset,
            begin_bit, end_bit,
            storage);
    }

    // Attempts the sort; returns false when the segment is too long for this
    // instantiation and the caller must fall back to another strategy.
    template<
        class KeysInputIterator,
        class KeysOutputIterator,
        class ValuesInputIterator,
        class ValuesOutputIterator
    >
    ROCPRIM_DEVICE ROCPRIM_INLINE
    bool sort(KeysInputIterator keys_input,
              KeysOutputIterator keys_output,
              ValuesInputIterator values_input,
              ValuesOutputIterator values_output,
              unsigned int begin_offset,
              unsigned int end_offset,
              unsigned int begin_bit,
              unsigned int end_bit,
              storage_type& storage)
    {
        constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
        using shorter_single_block_helper = segmented_radix_sort_single_block_helper<
            key_type, value_type, BlockSize, ItemsPerThread / 2, Descending
        >;
        // Segment is longer than supported by this function
        if(end_offset - begin_offset > items_per_block)
        {
            return false;
        }
        // Recursively check if it is possible to sort the segment using fewer items per thread
        // (the ItemsPerThread == 0 specialization terminates the recursion).
        const bool processed_by_shorter = shorter_single_block_helper().sort(
            keys_input, keys_output,
            values_input, values_output,
            begin_offset, end_offset,
            begin_bit, end_bit,
            // NOTE: the shorter helper's storage is assumed to fit inside
            // this helper's storage union.
            reinterpret_cast<typename shorter_single_block_helper::storage_type&>(storage)
        );
        if(processed_by_shorter)
        {
            return true;
        }

        key_type keys[ItemsPerThread];
        value_type values[ItemsPerThread];
        const unsigned int valid_count = end_offset - begin_offset;
        // Sort will leave "invalid" (out of size) items at the end of the sorted sequence
        const key_type out_of_bounds = key_codec::decode(bit_key_type(-1));

        keys_load_type().load(
            keys_input + begin_offset,
            keys,
            valid_count,
            out_of_bounds,
            storage.keys_load);
        if(with_values)
        {
            ::rocprim::syncthreads();
            values_load_type().load(
                values_input + begin_offset,
                values,
                valid_count,
                storage.values_load);
        }
        ::rocprim::syncthreads();

        sort_block<Descending>(
            sort_type(),
            keys,
            values,
            storage.sort,
            begin_bit,
            end_bit);
        ::rocprim::syncthreads();

        keys_store_type().store(
            keys_output + begin_offset,
            keys,
            valid_count,
            storage.keys_store);
        if(with_values)
        {
            ::rocprim::syncthreads();
            values_store_type().store(
                values_output + begin_offset,
                values,
                valid_count,
                storage.values_store);
        }
        return true;
    }
};
// Terminating specialization of the recursive single-block helper:
// with ItemsPerThread == 0 nothing can be sorted, so sort() always reports
// that it did not handle the segment.
template<
    class Key,
    class Value,
    unsigned int BlockSize,
    bool Descending
>
class segmented_radix_sort_single_block_helper<Key, Value, BlockSize, 0, Descending>
{
public:
    // No shared memory is needed.
    struct storage_type
    {
    };

    template<
        class KeysInputIterator,
        class KeysOutputIterator,
        class ValuesInputIterator,
        class ValuesOutputIterator
    >
    ROCPRIM_DEVICE ROCPRIM_INLINE
    bool sort(KeysInputIterator,
              KeysOutputIterator,
              ValuesInputIterator,
              ValuesOutputIterator,
              unsigned int,
              unsigned int,
              unsigned int,
              unsigned int,
              storage_type&)
    {
        // It can't sort anything because ItemsPerThread is 0.
        // The segment will be sorted by the caller (i.e. using ItemsPerThread = 1)
        return false;
    }
};
template
<
unsigned
int
LogicalWarpSize
,
unsigned
int
ItemsPerThread
,
unsigned
int
BlockSize
>
struct
WarpSortHelperConfig
{
static
constexpr
unsigned
int
logical_warp_size
=
LogicalWarpSize
;
static
constexpr
unsigned
int
items_per_thread
=
ItemsPerThread
;
static
constexpr
unsigned
int
block_size
=
BlockSize
;
};
struct
DisabledWarpSortHelperConfig
{
static
constexpr
unsigned
int
logical_warp_size
=
1
;
static
constexpr
unsigned
int
items_per_thread
=
1
;
static
constexpr
unsigned
int
block_size
=
1
;
};
// Maps a device-level warp sort config to the helper config for "small"
// segments; a disabled config selects the inert helper config.
template<class Config>
using select_warp_sort_helper_config_small_t = std::conditional_t<
    std::is_same<DisabledWarpSortConfig, Config>::value,
    DisabledWarpSortHelperConfig,
    WarpSortHelperConfig<Config::logical_warp_size_small,
                         Config::items_per_thread_small,
                         Config::block_size_small>>;
// Maps a device-level warp sort config to the helper config for "medium"
// segments; a disabled config selects the inert helper config.
template<class Config>
using select_warp_sort_helper_config_medium_t = std::conditional_t<
    std::is_same<DisabledWarpSortConfig, Config>::value,
    DisabledWarpSortHelperConfig,
    WarpSortHelperConfig<Config::logical_warp_size_medium,
                         Config::items_per_thread_medium,
                         Config::block_size_medium>>;
// Primary template: warp-level segmented sort disabled. items_per_warp == 0
// means no segment is ever routed to this helper, and sort() is a no-op.
template<
    class Config,
    class Key,
    class Value,
    bool Descending,
    class Enable = void
>
struct segmented_warp_sort_helper
{
    static constexpr unsigned int items_per_warp = 0;

    using storage_type = ::rocprim::empty_type;

    template<class... Args>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void sort(Args&&...)
    {
    }
};
// Enabled specialization: sorts a tiny segment (up to items_per_warp items)
// with a single logical warp using a comparison-based warp_sort. Keys are
// paired with their original lane-order index so the merge sort becomes
// stable, matching radix-sort semantics.
template<class Config, class Key, class Value, bool Descending>
class segmented_warp_sort_helper<
    Config,
    Key,
    Value,
    Descending,
    std::enable_if_t<!std::is_same<DisabledWarpSortHelperConfig, Config>::value>>
{
    static constexpr unsigned int logical_warp_size = Config::logical_warp_size;
    static constexpr unsigned int items_per_thread = Config::items_per_thread;

    using key_type = Key;
    using value_type = Value;

    // Encodes keys into orderable bit patterns (handles sign/descending).
    using key_codec = ::rocprim::detail::radix_key_codec<key_type, Descending>;
    using bit_key_type = typename key_codec::bit_key_type;

    using keys_load_type = ::rocprim::warp_load<
        key_type, items_per_thread, logical_warp_size,
        ::rocprim::warp_load_method::warp_load_striped>;
    using values_load_type = ::rocprim::warp_load<
        value_type, items_per_thread, logical_warp_size,
        ::rocprim::warp_load_method::warp_load_striped>;
    using keys_store_type = ::rocprim::warp_store<
        key_type, items_per_thread, logical_warp_size>;
    using values_store_type = ::rocprim::warp_store<
        value_type, items_per_thread, logical_warp_size>;

    // Comparator matching radix order; UseRadixMask restricts the compared
    // bit range to [begin_bit, end_bit).
    template<bool UseRadixMask>
    using radix_comparator_type = ::rocprim::detail::radix_merge_compare<
        Descending, UseRadixMask, key_type>;

    // Key augmented with its original position to make the sort stable.
    using stable_key_type = ::rocprim::tuple<key_type, unsigned int>;
    using sort_type = ::rocprim::warp_sort<
        stable_key_type, logical_warp_size, value_type>;

    // Keys-only sorts pass ::rocprim::empty_type as the value type.
    static constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;

    // Wraps a key comparator so ties on the key are broken by the original
    // index, making the overall ordering stable.
    template<class ComparatorT>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    decltype(auto) make_stable_comparator(ComparatorT comparator)
    {
        return [comparator](const stable_key_type& a, const stable_key_type& b) -> bool
        {
            const bool ab = comparator(rocprim::get<0>(a), rocprim::get<0>(b));
            const bool ba = comparator(rocprim::get<0>(b), rocprim::get<0>(a));
            // Equal keys (neither compares before the other) fall back to
            // the original index.
            return ab || (!ba && (rocprim::get<1>(a) < rocprim::get<1>(b)));
        };
    }

public:
    // Maximum segment length this helper can handle.
    static constexpr unsigned int items_per_warp = items_per_thread * logical_warp_size;

    // Phases are separated by wave barriers, so storages may alias.
    union storage_type
    {
        typename keys_load_type::storage_type keys_load;
        typename values_load_type::storage_type values_load;
        typename keys_store_type::storage_type keys_store;
        typename values_store_type::storage_type values_store;
        typename sort_type::storage_type sort;
    };

    template<
        class KeysInputIterator,
        class KeysOutputIterator,
        class ValuesInputIterator,
        class ValuesOutputIterator
    >
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void sort(KeysInputIterator keys_input,
              KeysOutputIterator keys_output,
              ValuesInputIterator values_input,
              ValuesOutputIterator values_output,
              unsigned int begin_offset,
              unsigned int end_offset,
              unsigned int begin_bit,
              unsigned int end_bit,
              storage_type& storage)
    {
        const unsigned int num_items = end_offset - begin_offset;
        // Padding value for out-of-range slots; sorts to the end of the
        // sequence, which is never stored (stores are bounded by num_items).
        const key_type out_of_bounds = key_codec::decode(bit_key_type(-1));

        key_type keys[items_per_thread];
        stable_key_type stable_keys[items_per_thread];
        value_type values[items_per_thread];

        keys_load_type().load(
            keys_input + begin_offset,
            keys,
            num_items,
            out_of_bounds,
            storage.keys_load);

        // Attach each key's original (striped) position for stability.
        ROCPRIM_UNROLL
        for(unsigned int i = 0; i < items_per_thread; i++)
        {
            ::rocprim::get<0>(stable_keys[i]) = keys[i];
            ::rocprim::get<1>(stable_keys[i])
                = ::rocprim::detail::logical_lane_id<logical_warp_size>()
                  + logical_warp_size * i;
        }

        if(with_values)
        {
            ::rocprim::wave_barrier();
            values_load_type().load(
                values_input + begin_offset,
                values,
                num_items,
                storage.values_load);
        }
        ::rocprim::wave_barrier();

        if(begin_bit == 0 && end_bit == 8 * sizeof(key_type))
        {
            // Full-width compare: no bit mask needed.
            sort_type().sort(
                stable_keys,
                values,
                storage.sort,
                make_stable_comparator(radix_comparator_type<false>{}));
        }
        else
        {
            // Partial bit range: mask the comparison to [begin_bit, end_bit).
            radix_comparator_type<true> comparator(begin_bit, end_bit - begin_bit);
            sort_type().sort(
                stable_keys,
                values,
                storage.sort,
                make_stable_comparator(comparator));
        }

        // Strip the stability indices before storing.
        ROCPRIM_UNROLL
        for(unsigned int i = 0; i < items_per_thread; i++)
        {
            keys[i] = ::rocprim::get<0>(stable_keys[i]);
        }
        ::rocprim::wave_barrier();

        keys_store_type().store(
            keys_output + begin_offset,
            keys,
            num_items,
            storage.keys_store);
        if(with_values)
        {
            ::rocprim::wave_barrier();
            values_store_type().store(
                values_output + begin_offset,
                values,
                num_items,
                storage.values_store);
        }
    }

    // Selects the destination (output or tmp) based on `to_output` and sorts.
    template<
        class KeysInputIterator,
        class KeysOutputIterator,
        class ValuesInputIterator,
        class ValuesOutputIterator
    >
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void sort(KeysInputIterator keys_input,
              key_type* keys_tmp,
              KeysOutputIterator keys_output,
              ValuesInputIterator values_input,
              value_type* values_tmp,
              ValuesOutputIterator values_output,
              bool to_output,
              unsigned int begin_offset,
              unsigned int end_offset,
              unsigned int begin_bit,
              unsigned int end_bit,
              storage_type& storage)
    {
        if(to_output)
        {
            sort(
                keys_input, keys_output,
                values_input, values_output,
                begin_offset, end_offset,
                begin_bit, end_bit,
                storage);
        }
        else
        {
            sort(
                keys_input, keys_tmp,
                values_input, values_tmp,
                begin_offset, end_offset,
                begin_bit, end_bit,
                storage);
        }
    }
};
// Kernel body sorting one segment per thread block. Dispatches on segment
// length: multi-pass block radix sort for large segments, one block-wide
// sort for small ones, and a single-warp sort for tiny ones (when enabled).
// `to_output` and the pass counts together determine buffer parity so the
// final result always lands in the caller-chosen buffer.
template<
    class Config,
    bool Descending,
    class KeysInputIterator,
    class KeysOutputIterator,
    class ValuesInputIterator,
    class ValuesOutputIterator,
    class OffsetIterator
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void segmented_sort(KeysInputIterator keys_input,
                    typename std::iterator_traits<KeysInputIterator>::value_type* keys_tmp,
                    KeysOutputIterator keys_output,
                    ValuesInputIterator values_input,
                    typename std::iterator_traits<ValuesInputIterator>::value_type* values_tmp,
                    ValuesOutputIterator values_output,
                    bool to_output,
                    OffsetIterator begin_offsets,
                    OffsetIterator end_offsets,
                    unsigned int long_iterations,
                    unsigned int short_iterations,
                    unsigned int begin_bit,
                    unsigned int end_bit)
{
    constexpr unsigned int long_radix_bits = Config::long_radix_bits;
    constexpr unsigned int short_radix_bits = Config::short_radix_bits;
    constexpr unsigned int block_size = Config::sort::block_size;
    constexpr unsigned int items_per_thread = Config::sort::items_per_thread;
    constexpr unsigned int items_per_block = block_size * items_per_thread;
    constexpr bool warp_sort_enabled = Config::warp_sort_config::enable_unpartitioned_warp_sort;

    using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
    using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;

    using single_block_helper_type = segmented_radix_sort_single_block_helper<
        key_type, value_type, block_size, items_per_thread, Descending
    >;
    using long_radix_helper_type = segmented_radix_sort_helper<
        key_type, value_type,
        ::rocprim::device_warp_size(),
        block_size, items_per_thread,
        long_radix_bits, Descending
    >;
    using short_radix_helper_type = segmented_radix_sort_helper<
        key_type, value_type,
        ::rocprim::device_warp_size(),
        block_size, items_per_thread,
        short_radix_bits, Descending
    >;
    using warp_sort_helper_type = segmented_warp_sort_helper<
        select_warp_sort_helper_config_small_t<typename Config::warp_sort_config>,
        key_type, value_type, Descending
    >;
    static constexpr unsigned int items_per_warp = warp_sort_helper_type::items_per_warp;

    // Only one strategy runs per block, so all storages may alias.
    ROCPRIM_SHARED_MEMORY union
    {
        typename single_block_helper_type::storage_type single_block_helper;
        typename long_radix_helper_type::storage_type long_radix_helper;
        typename short_radix_helper_type::storage_type short_radix_helper;
        typename warp_sort_helper_type::storage_type warp_sort_helper;
    } storage;

    // One segment per block.
    const unsigned int segment_id = ::rocprim::detail::block_id<0>();
    const unsigned int begin_offset = begin_offsets[segment_id];
    const unsigned int end_offset = end_offsets[segment_id];

    // Empty segment
    if(end_offset <= begin_offset)
    {
        return;
    }

    if(end_offset - begin_offset > items_per_block)
    {
        // Large segment: multiple radix passes, long_radix_bits-wide passes
        // first, then short_radix_bits-wide passes, ping-ponging buffers.
        unsigned int bit = begin_bit;
        for(unsigned int i = 0; i < long_iterations; i++)
        {
            long_radix_helper_type().sort(
                keys_input, keys_tmp, keys_output,
                values_input, values_tmp, values_output,
                to_output,
                begin_offset, end_offset,
                bit, begin_bit, end_bit,
                storage.long_radix_helper);
            to_output = !to_output;
            bit += long_radix_bits;
        }
        for(unsigned int i = 0; i < short_iterations; i++)
        {
            short_radix_helper_type().sort(
                keys_input, keys_tmp, keys_output,
                values_input, values_tmp, values_output,
                to_output,
                begin_offset, end_offset,
                bit, begin_bit, end_bit,
                storage.short_radix_helper);
            to_output = !to_output;
            bit += short_radix_bits;
        }
    }
    else if(!warp_sort_enabled || end_offset - begin_offset > items_per_warp)
    {
        // Small segment: one block-wide sort. The parity expression picks
        // the destination so this single pass ends in the same buffer a
        // (long_iterations + short_iterations)-pass radix sort would.
        single_block_helper_type().sort(
            keys_input, keys_tmp, keys_output,
            values_input, values_tmp, values_output,
            ((long_iterations + short_iterations) % 2 == 0) != to_output,
            begin_offset, end_offset,
            begin_bit, end_bit,
            storage.single_block_helper);
    }
    else if(::rocprim::flat_block_thread_id() < Config::warp_sort_config::logical_warp_size_small)
    {
        // Single warp segment: only the first logical warp participates.
        warp_sort_helper_type().sort(
            keys_input, keys_tmp, keys_output,
            values_input, values_tmp, values_output,
            ((long_iterations + short_iterations) % 2 == 0) != to_output,
            begin_offset, end_offset,
            begin_bit, end_bit,
            storage.warp_sort_helper);
    }
}
// Sorts one "large" segment per thread block, selected indirectly through
// segment_indices (this kernel is launched only for segments binned as large).
//
// keys_tmp/values_tmp are double buffers: each radix pass reads from one buffer
// and writes to the other, and `to_output` tracks which direction the next pass
// takes. Segments longer than one block-tile run `long_iterations` passes of
// `long_radix_bits` followed by `short_iterations` passes of `short_radix_bits`;
// shorter segments are sorted in a single block-wide pass.
template<
    class Config,
    bool Descending,
    class KeysInputIterator,
    class KeysOutputIterator,
    class ValuesInputIterator,
    class ValuesOutputIterator,
    class SegmentIndexIterator,
    class OffsetIterator
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void segmented_sort_large(KeysInputIterator keys_input,
                          typename std::iterator_traits<KeysInputIterator>::value_type * keys_tmp,
                          KeysOutputIterator keys_output,
                          ValuesInputIterator values_input,
                          typename std::iterator_traits<ValuesInputIterator>::value_type * values_tmp,
                          ValuesOutputIterator values_output,
                          bool to_output,
                          SegmentIndexIterator segment_indices,
                          OffsetIterator begin_offsets,
                          OffsetIterator end_offsets,
                          unsigned int long_iterations,
                          unsigned int short_iterations,
                          unsigned int begin_bit,
                          unsigned int end_bit)
{
    constexpr unsigned int long_radix_bits = Config::long_radix_bits;
    constexpr unsigned int short_radix_bits = Config::short_radix_bits;
    constexpr unsigned int block_size = Config::sort::block_size;
    constexpr unsigned int items_per_thread = Config::sort::items_per_thread;
    // One block-tile; segments at most this long take the single-pass path.
    constexpr unsigned int items_per_block = block_size * items_per_thread;

    using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
    using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;

    using single_block_helper_type = segmented_radix_sort_single_block_helper<
        key_type, value_type,
        block_size, items_per_thread,
        Descending
    >;
    using long_radix_helper_type = segmented_radix_sort_helper<
        key_type, value_type,
        ::rocprim::device_warp_size(), block_size, items_per_thread,
        long_radix_bits, Descending
    >;
    using short_radix_helper_type = segmented_radix_sort_helper<
        key_type, value_type,
        ::rocprim::device_warp_size(), block_size, items_per_thread,
        short_radix_bits, Descending
    >;

    // The helpers run one at a time, so their shared-memory storages can alias.
    ROCPRIM_SHARED_MEMORY union
    {
        typename single_block_helper_type::storage_type single_block_helper;
        typename long_radix_helper_type::storage_type long_radix_helper;
        typename short_radix_helper_type::storage_type short_radix_helper;
    } storage;

    const unsigned int block_id = ::rocprim::detail::block_id<0>();
    // Indirect lookup: this block's segment id comes from the binned index list.
    const unsigned int segment_id = segment_indices[block_id];

    const unsigned int begin_offset = begin_offsets[segment_id];
    const unsigned int end_offset = end_offsets[segment_id];

    // Nothing to do for an empty segment.
    if(end_offset <= begin_offset)
    {
        return;
    }

    if(end_offset - begin_offset > items_per_block)
    {
        // Multi-pass radix sort: `bit` advances through [begin_bit, end_bit),
        // and `to_output` flips after every pass (ping-pong buffering).
        unsigned int bit = begin_bit;
        for(unsigned int i = 0; i < long_iterations; i++)
        {
            long_radix_helper_type().sort(
                keys_input, keys_tmp, keys_output,
                values_input, values_tmp, values_output,
                to_output,
                begin_offset, end_offset,
                bit, begin_bit, end_bit,
                storage.long_radix_helper
            );
            to_output = !to_output;
            bit += long_radix_bits;
        }
        for(unsigned int i = 0; i < short_iterations; i++)
        {
            short_radix_helper_type().sort(
                keys_input, keys_tmp, keys_output,
                values_input, values_tmp, values_output,
                to_output,
                begin_offset, end_offset,
                bit, begin_bit, end_bit,
                storage.short_radix_helper
            );
            to_output = !to_output;
            bit += short_radix_bits;
        }
    }
    else
    {
        // Single-pass path. The parity expression makes this segment land in the
        // same final buffer as segments that run all (long + short) passes
        // above — NOTE(review): intent inferred from the parity expression.
        single_block_helper_type().sort(
            keys_input, keys_tmp, keys_output,
            values_input, values_tmp, values_output,
            ((long_iterations + short_iterations) % 2 == 0) != to_output,
            begin_offset, end_offset,
            begin_bit, end_bit,
            storage.single_block_helper
        );
    }
}
// Sorts "small" segments, one segment per logical warp. Every thread block
// hosts block_size / logical_warp_size logical warps; each one picks its
// segment id out of segment_indices and delegates to the warp-sort helper.
template<
    class Config,
    bool Descending,
    class KeysInputIterator,
    class KeysOutputIterator,
    class ValuesInputIterator,
    class ValuesOutputIterator,
    class SegmentIndexIterator,
    class OffsetIterator
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void segmented_sort_small(KeysInputIterator keys_input,
                          typename std::iterator_traits<KeysInputIterator>::value_type * keys_tmp,
                          KeysOutputIterator keys_output,
                          ValuesInputIterator values_input,
                          typename std::iterator_traits<ValuesInputIterator>::value_type * values_tmp,
                          ValuesOutputIterator values_output,
                          bool to_output,
                          unsigned int num_segments,
                          SegmentIndexIterator segment_indices,
                          OffsetIterator begin_offsets,
                          OffsetIterator end_offsets,
                          unsigned int begin_bit,
                          unsigned int end_bit)
{
    static constexpr unsigned int block_size = Config::block_size;
    static constexpr unsigned int logical_warp_size = Config::logical_warp_size;
    static_assert(block_size % logical_warp_size == 0,
                  "logical_warp_size must be a divisor of block_size");
    static constexpr unsigned int warps_per_block = block_size / logical_warp_size;

    using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
    using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;

    using warp_sort_helper_type
        = segmented_warp_sort_helper<Config, key_type, value_type, Descending>;

    // Shared storage for the warp-sort helper (shared by all warps in the block).
    ROCPRIM_SHARED_MEMORY typename warp_sort_helper_type::storage_type storage;

    const unsigned int grid_block_id = ::rocprim::detail::block_id<0>();
    const unsigned int warp_in_block
        = ::rocprim::detail::logical_warp_id<logical_warp_size>();

    // Global index of this logical warp; it doubles as the position in the
    // small-segment index list.
    const unsigned int global_warp_index
        = grid_block_id * warps_per_block + warp_in_block;
    if(global_warp_index >= num_segments)
    {
        // Excess warps in the last block have no segment assigned.
        return;
    }

    const unsigned int segment_id = segment_indices[global_warp_index];
    const unsigned int segment_begin = begin_offsets[segment_id];
    const unsigned int segment_end = end_offsets[segment_id];
    if(segment_end <= segment_begin)
    {
        // Empty segment: nothing to sort.
        return;
    }

    warp_sort_helper_type().sort(keys_input, keys_tmp, keys_output,
                                 values_input, values_tmp, values_output,
                                 to_output,
                                 segment_begin, segment_end,
                                 begin_bit, end_bit,
                                 storage);
}
}
// end namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_RADIX_SORT_HPP_
3rdparty/cub/rocprim/device/detail/device_segmented_reduce.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_REDUCE_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_REDUCE_HPP_
#include <type_traits>
#include <iterator>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../types.hpp"
#include "../../block/block_load_func.hpp"
#include "../../block/block_reduce.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace
detail
{
// Reduces one segment per thread block (segment id == block id).
//
// Empty segments produce `initial_value`; otherwise the final result is
// reduce_op(initial_value, <reduction of the segment>), written by thread 0.
// Short segments (< one block-tile) are reduced with a strided per-thread loop;
// long segments stream full block-tiles through registers.
template<
    class Config,
    class InputIterator,
    class OutputIterator,
    class OffsetIterator,
    class ResultType,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void segmented_reduce(InputIterator input,
                      OutputIterator output,
                      OffsetIterator begin_offsets,
                      OffsetIterator end_offsets,
                      BinaryFunction reduce_op,
                      ResultType initial_value)
{
    constexpr unsigned int block_size = Config::block_size;
    constexpr unsigned int items_per_thread = Config::items_per_thread;
    constexpr unsigned int items_per_block = block_size * items_per_thread;

    using reduce_type = ::rocprim::block_reduce<
        ResultType, block_size,
        Config::block_reduce_method
    >;
    ROCPRIM_SHARED_MEMORY typename reduce_type::storage_type reduce_storage;

    const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
    // One block per segment: the block id is the segment id.
    const unsigned int segment_id = ::rocprim::detail::block_id<0>();

    const unsigned int begin_offset = begin_offsets[segment_id];
    const unsigned int end_offset = end_offsets[segment_id];

    // Empty segment: the result is just the initial value.
    if(end_offset <= begin_offset)
    {
        if(flat_id == 0)
        {
            output[segment_id] = initial_value;
        }
        return;
    }

    // Deliberately left uninitialized: threads that load no items pass an
    // unused value into the valid_count-aware block reduce below.
    ResultType result;
    unsigned int block_offset = begin_offset;
    if(block_offset + items_per_block > end_offset)
    {
        // Segment is shorter than items_per_block.
        // Each thread strides through the segment with step block_size and
        // reduces its own items into `result`.
        const unsigned int valid_count = end_offset - block_offset;
        if(flat_id < valid_count)
        {
            unsigned int offset = block_offset + flat_id;
            result = input[offset];
            offset += block_size;
            while(offset < end_offset)
            {
                result = reduce_op(result, static_cast<ResultType>(input[offset]));
                offset += block_size;
            }
        }
        // Reduce threads' reductions to compute the final result.
        if(valid_count >= block_size)
        {
            // All threads have at least one value, i.e. result has valid value.
            reduce_type().reduce(
                result, // input
                result, // output
                reduce_storage,
                reduce_op
            );
        }
        else
        {
            // Only the first valid_count threads hold a valid `result`.
            reduce_type().reduce(
                result, // input
                result, // output
                valid_count,
                reduce_storage,
                reduce_op
            );
        }
    }
    else
    {
        // Long segment: stream full block-tiles through a register array.
        ResultType values[items_per_thread];

        // Load the first (guaranteed full) block-tile and seed `result`.
        block_load_direct_striped<block_size>(flat_id, input + block_offset, values);
        result = values[0];
        for(unsigned int i = 1; i < items_per_thread; i++)
        {
            result = reduce_op(result, values[i]);
        }
        block_offset += items_per_block;

        // Consume every remaining full block-tile.
        while(block_offset + items_per_block < end_offset)
        {
            block_load_direct_striped<block_size>(flat_id, input + block_offset, values);
            for(unsigned int i = 0; i < items_per_thread; i++)
            {
                result = reduce_op(result, values[i]);
            }
            block_offset += items_per_block;
        }

        // Last (probably partial) block-tile: guard each item against
        // valid_count so out-of-range loads are skipped.
        const unsigned int valid_count = end_offset - block_offset;
        block_load_direct_striped<block_size>(flat_id, input + block_offset, values, valid_count);
        for(unsigned int i = 0; i < items_per_thread; i++)
        {
            if(i * block_size + flat_id < valid_count)
            {
                result = reduce_op(result, values[i]);
            }
        }

        // Reduce threads' reductions to compute the final result; every thread
        // holds a valid `result` here (the first tile was full).
        reduce_type().reduce(
            result, // input
            result, // output
            reduce_storage,
            reduce_op
        );
    }

    // Thread 0 owns the block-wide reduction result; fold in the initial value.
    if(flat_id == 0)
    {
        output[segment_id] = reduce_op(initial_value, result);
    }
}
}
// end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_REDUCE_HPP_
3rdparty/cub/rocprim/device/detail/device_segmented_scan.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_SCAN_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_SCAN_HPP_
#include <type_traits>
#include <iterator>
#include "../../config.hpp"
#include "../../intrinsics.hpp"
#include "../../types.hpp"
#include "../../detail/various.hpp"
#include "../../detail/binary_op_wrappers.hpp"
#include "../../block/block_load.hpp"
#include "../../block/block_store.hpp"
#include "../../block/block_scan.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace
detail
{
// Exclusive-scan flavour, selected via SFINAE on Exclusive. The running
// `prefix` is always carried across tiles through a callback, so UsePrefix is
// not consulted here (an exclusive scan needs the prefix as its initial value).
template<
    bool Exclusive,
    bool UsePrefix,
    class BlockScanType,
    class T,
    unsigned int ItemsPerThread,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto segmented_scan_block_scan(T (&input)[ItemsPerThread],
                               T (&output)[ItemsPerThread],
                               T& prefix,
                               typename BlockScanType::storage_type& storage,
                               BinaryFunction scan_op)
    -> typename std::enable_if<Exclusive>::type
{
    // Hand the previous tiles' running total to the block scan, then fold this
    // tile's reduction back into it for the next call.
    auto running_prefix_op = [&prefix, scan_op](const T& block_reduction)
    {
        const auto tile_prefix = prefix;
        prefix = scan_op(prefix, block_reduction);
        return tile_prefix;
    };
    BlockScanType().exclusive_scan(input, output, storage, running_prefix_op, scan_op);
}
// Inclusive-scan flavour, selected via SFINAE on !Exclusive. For the first
// tile of a segment (UsePrefix == false) the block scan simply reports the
// tile reduction into `prefix`; later tiles feed the running prefix back in
// through a callback.
template<
    bool Exclusive,
    bool UsePrefix,
    class BlockScanType,
    class T,
    unsigned int ItemsPerThread,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto segmented_scan_block_scan(T (&input)[ItemsPerThread],
                               T (&output)[ItemsPerThread],
                               T& prefix,
                               typename BlockScanType::storage_type& storage,
                               BinaryFunction scan_op)
    -> typename std::enable_if<!Exclusive>::type
{
    if(UsePrefix)
    {
        // Continuation tile: inject the running prefix and keep it updated.
        auto running_prefix_op = [&prefix, scan_op](const T& block_reduction)
        {
            const auto tile_prefix = prefix;
            prefix = scan_op(prefix, block_reduction);
            return tile_prefix;
        };
        BlockScanType().inclusive_scan(input, output, storage, running_prefix_op, scan_op);
    }
    else
    {
        // First tile: no prefix to inject; the overload writes the tile
        // reduction into `prefix` for subsequent tiles.
        BlockScanType().inclusive_scan(input, output, prefix, storage, scan_op);
    }
}
// Scans one segment per thread block (segment id == block id), either
// inclusively or exclusively depending on Exclusive.
//
// The segment is processed tile by tile; `prefix` carries the running total
// between tiles via segmented_scan_block_scan. Shared memory for load, store
// and scan is a union, so a syncthreads() separates every reuse.
template<
    bool Exclusive,
    class Config,
    class ResultType,
    class InputIterator,
    class OutputIterator,
    class OffsetIterator,
    class InitValueType,
    class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void segmented_scan(InputIterator input,
                    OutputIterator output,
                    OffsetIterator begin_offsets,
                    OffsetIterator end_offsets,
                    InitValueType initial_value,
                    BinaryFunction scan_op)
{
    constexpr auto block_size = Config::block_size;
    constexpr auto items_per_thread = Config::items_per_thread;
    constexpr unsigned int items_per_block = block_size * items_per_thread;

    using result_type = ResultType;
    using block_load_type = ::rocprim::block_load<
        result_type, block_size, items_per_thread,
        Config::block_load_method
    >;
    using block_store_type = ::rocprim::block_store<
        result_type, block_size, items_per_thread,
        Config::block_store_method
    >;
    using block_scan_type = ::rocprim::block_scan<
        result_type, block_size,
        Config::block_scan_method
    >;

    // Load/store/scan never run concurrently, so their storages alias;
    // every transition below is fenced with syncthreads().
    ROCPRIM_SHARED_MEMORY union
    {
        typename block_load_type::storage_type load;
        typename block_store_type::storage_type store;
        typename block_scan_type::storage_type scan;
    } storage;

    // One block per segment.
    const unsigned int segment_id = ::rocprim::detail::block_id<0>();
    const unsigned int begin_offset = begin_offsets[segment_id];
    const unsigned int end_offset = end_offsets[segment_id];

    // Empty segment
    if(end_offset <= begin_offset)
    {
        return;
    }

    // Input values
    result_type values[items_per_thread];
    // Running total carried between tiles; starts at the caller's init value.
    result_type prefix = initial_value;

    unsigned int block_offset = begin_offset;
    if(block_offset + items_per_block > end_offset)
    {
        // Segment is shorter than items_per_block
        // Load the partial block
        const unsigned int valid_count = end_offset - block_offset;
        block_load_type().load(input + block_offset, values, valid_count, storage.load);
        ::rocprim::syncthreads();
        // Perform scan operation (first tile: UsePrefix == false)
        segmented_scan_block_scan<Exclusive, false, block_scan_type>(
            values, values,
            prefix,
            storage.scan,
            scan_op
        );
        ::rocprim::syncthreads();
        // Store the partial block
        block_store_type().store(output + block_offset, values, valid_count, storage.store);
    }
    else
    {
        // Long segments
        // Load the first block of input values
        block_load_type().load(input + block_offset, values, storage.load);
        ::rocprim::syncthreads();
        // Perform scan operation (first tile: UsePrefix == false)
        segmented_scan_block_scan<Exclusive, false, block_scan_type>(
            values, values,
            prefix,
            storage.scan,
            scan_op
        );
        ::rocprim::syncthreads();
        // Store
        block_store_type().store(output + block_offset, values, storage.store);
        ::rocprim::syncthreads();
        block_offset += items_per_block;

        // Load next full blocks and continue scanning
        while(block_offset + items_per_block < end_offset)
        {
            block_load_type().load(input + block_offset, values, storage.load);
            ::rocprim::syncthreads();
            // Perform scan operation (continuation tile: UsePrefix == true)
            segmented_scan_block_scan<Exclusive, true, block_scan_type>(
                values, values,
                prefix,
                storage.scan,
                scan_op
            );
            ::rocprim::syncthreads();
            block_store_type().store(output + block_offset, values, storage.store);
            ::rocprim::syncthreads();
            block_offset += items_per_block;
        }

        // Load the last (probably partial) block and continue scanning
        const unsigned int valid_count = end_offset - block_offset;
        block_load_type().load(input + block_offset, values, valid_count, storage.load);
        ::rocprim::syncthreads();
        // Perform scan operation
        segmented_scan_block_scan<Exclusive, true, block_scan_type>(
            values, values,
            prefix,
            storage.scan,
            scan_op
        );
        ::rocprim::syncthreads();
        // Store the partial block
        block_store_type().store(output + block_offset, values, valid_count, storage.store);
    }
}
}
// end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_SCAN_HPP_
3rdparty/cub/rocprim/device/detail/device_transform.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_TRANSFORM_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_TRANSFORM_HPP_
#include <type_traits>
#include <iterator>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../detail/match_result_type.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../types.hpp"
#include "../../block/block_load.hpp"
#include "../../block/block_store.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace
detail
{
// Adapts a two-argument functor so it can be called with a single
// rocprim::tuple<T1, T2> argument — used by the two-input-iterator overload of
// transform, which zips its inputs into tuples.
template<class T1, class T2, class BinaryFunction>
struct unpack_binary_op
{
    // Result type of invoking BinaryFunction with (T1, T2).
    using result_type = typename ::rocprim::detail::invoke_result<BinaryFunction, T1, T2>::type;

    ROCPRIM_HOST_DEVICE inline
    unpack_binary_op() = default;

    ROCPRIM_HOST_DEVICE inline
    unpack_binary_op(BinaryFunction fn)
        : binary_op_(fn)
    {
    }

    ROCPRIM_HOST_DEVICE inline
    ~unpack_binary_op() = default;

    // Unpacks the tuple and forwards both elements to the wrapped functor.
    ROCPRIM_HOST_DEVICE inline
    result_type operator()(const ::rocprim::tuple<T1, T2>& t)
    {
        const auto& first = ::rocprim::get<0>(t);
        const auto& second = ::rocprim::get<1>(t);
        return binary_op_(first, second);
    }

private:
    BinaryFunction binary_op_;
};
// Applies transform_op element-wise over one block-tile of `input`, writing to
// `output`. All blocks except the last load/store full tiles; the last block
// guards every access against the number of remaining valid items.
template<
    unsigned int BlockSize,
    unsigned int ItemsPerThread,
    class ResultType,
    class InputIterator,
    class OutputIterator,
    class UnaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void transform_kernel_impl(InputIterator input,
                           const size_t input_size,
                           OutputIterator output,
                           UnaryFunction transform_op)
{
    using input_type = typename std::iterator_traits<InputIterator>::value_type;
    using output_type = typename std::iterator_traits<OutputIterator>::value_type;
    // Fall back to ResultType when the output iterator erases its value type
    // (e.g. a discard iterator reporting void).
    using result_type = typename std::conditional<
        std::is_void<output_type>::value, ResultType, output_type
    >::type;

    constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;

    const unsigned int thread_rank = ::rocprim::detail::block_thread_id<0>();
    const unsigned int block_index = ::rocprim::detail::block_id<0>();
    const unsigned int tile_offset = block_index * items_per_block;
    const unsigned int grid_blocks = ::rocprim::detail::grid_size<0>();

    input_type loaded[ItemsPerThread];
    result_type transformed[ItemsPerThread];

    if(block_index != grid_blocks - 1)
    {
        // Interior block: the tile is guaranteed full, no bounds checks needed.
        block_load_direct_striped<BlockSize>(thread_rank, input + tile_offset, loaded);
        ROCPRIM_UNROLL
        for(unsigned int item = 0; item < ItemsPerThread; item++)
        {
            transformed[item] = transform_op(loaded[item]);
        }
        block_store_direct_striped<BlockSize>(thread_rank, output + tile_offset, transformed);
    }
    else
    {
        // Last block: only `remaining` items are valid in this tile.
        const unsigned int remaining = input_size - tile_offset;
        block_load_direct_striped<BlockSize>(thread_rank,
                                             input + tile_offset,
                                             loaded,
                                             remaining);
        ROCPRIM_UNROLL
        for(unsigned int item = 0; item < ItemsPerThread; item++)
        {
            if(BlockSize * item + thread_rank < remaining)
            {
                transformed[item] = transform_op(loaded[item]);
            }
        }
        block_store_direct_striped<BlockSize>(thread_rank,
                                              output + tile_offset,
                                              transformed,
                                              remaining);
    }
}
}
// end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_TRANSFORM_HPP_
3rdparty/cub/rocprim/device/detail/lookback_scan_state.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_LOOKBACK_SCAN_STATE_HPP_
#define ROCPRIM_DEVICE_DETAIL_LOOKBACK_SCAN_STATE_HPP_
#include <type_traits>
#include "../../intrinsics.hpp"
#include "../../types.hpp"
#include "../../type_traits.hpp"
#include "../../warp/detail/warp_reduce_crosslane.hpp"
#include "../../warp/detail/warp_scan_crosslane.hpp"
#include "../../detail/various.hpp"
#include "../../detail/binary_op_wrappers.hpp"
extern
"C"
{
void
__builtin_amdgcn_s_sleep
(
int
);
}
BEGIN_ROCPRIM_NAMESPACE
// Single pass prefix scan was implemented based on:
// Merrill, D. and Garland, M. Single-pass Parallel Prefix Scan with Decoupled Look-back.
// Technical Report NVR2016-001, NVIDIA Research. Mar. 2016.
namespace
detail
{
enum
prefix_flag
{
// flag for padding, values should be discarded
PREFIX_INVALID
=
-
1
,
// initialized, not result in value
PREFIX_EMPTY
=
0
,
// partial prefix value (from single block)
PREFIX_PARTIAL
=
1
,
// final prefix value
PREFIX_COMPLETE
=
2
};
// lookback_scan_state object keeps track of prefixes status for
// a look-back prefix scan. Initially every prefix can be either
// invalid (padding values) or empty. One thread in a block should
// later set it to partial, and later to complete.
template
<
class
T
,
bool
UseSleep
=
false
,
bool
IsSmall
=
(
sizeof
(
T
)
<=
4
)>
struct
lookback_scan_state
;
// Specialization for small T: the flag and the prefix value are packed into
// one 32/64-bit word, so each publish/read is a single atomic operation and
// no memory fences are needed.
template<class T, bool UseSleep>
struct lookback_scan_state<T, UseSleep, true>
{
private:
    using flag_type_ = char;

    // Type which is used in store/load operations of block prefix (flag and value).
    // It is 32-bit or 64-bit int and can be loaded/stored using single atomic instruction.
    using prefix_underlying_type = typename std::conditional<
        (sizeof(T) > 2),
        unsigned long long,
        unsigned int
    >::type;

    // Packed {flag, value} record; aligned so it can be type-punned to
    // prefix_underlying_type via memcpy.
    struct alignas(sizeof(prefix_underlying_type)) prefix_type
    {
        flag_type_ flag;
        T value;
    };

    static_assert(sizeof(prefix_underlying_type) == sizeof(prefix_type), "");

public:
    // Type used for flag/flag of block prefix
    using flag_type = flag_type_;
    using value_type = T;

    // temp_storage must point to allocation of get_storage_size(number_of_blocks) bytes
    ROCPRIM_HOST
    static inline lookback_scan_state create(void* temp_storage, const unsigned int number_of_blocks)
    {
        (void) number_of_blocks;
        lookback_scan_state state;
        state.prefixes = reinterpret_cast<prefix_underlying_type*>(temp_storage);
        return state;
    }

    // Bytes required: one packed word per block plus a warp-sized padding
    // region at the front (see initialize_prefix).
    ROCPRIM_HOST
    static inline size_t get_storage_size(const unsigned int number_of_blocks)
    {
        return sizeof(prefix_underlying_type) * (::rocprim::host_warp_size() + number_of_blocks);
    }

    // Marks this block's slot EMPTY and, for the first `padding` ids, marks
    // the leading padding slots INVALID.
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void initialize_prefix(const unsigned int block_id,
                           const unsigned int number_of_blocks)
    {
        constexpr unsigned int padding = ::rocprim::device_warp_size();

        if(block_id < number_of_blocks)
        {
            prefix_type prefix;
            prefix.flag = PREFIX_EMPTY;
            prefix_underlying_type p;
            // memcpy (not a cast) to type-pun the packed record legally.
#ifndef __HIP_CPU_RT__
            __builtin_memcpy(&p, &prefix, sizeof(prefix_type));
#else
            std::memcpy(&p, &prefix, sizeof(prefix_type));
#endif
            prefixes[padding + block_id] = p;
        }
        if(block_id < padding)
        {
            prefix_type prefix;
            prefix.flag = PREFIX_INVALID;
            prefix_underlying_type p;
#ifndef __HIP_CPU_RT__
            __builtin_memcpy(&p, &prefix, sizeof(prefix_type));
#else
            std::memcpy(&p, &prefix, sizeof(prefix_type));
#endif
            prefixes[block_id] = p;
        }
    }

    // Publishes this block's own reduction (PARTIAL).
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void set_partial(const unsigned int block_id, const T value)
    {
        this->set(block_id, PREFIX_PARTIAL, value);
    }

    // Publishes this block's final inclusive prefix (COMPLETE).
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void set_complete(const unsigned int block_id, const T value)
    {
        this->set(block_id, PREFIX_COMPLETE, value);
    }

    // Spin-waits until block_id's record leaves EMPTY, then returns its flag
    // and value. block_id must be > 0.
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void get(const unsigned int block_id, flag_type& flag, T& value)
    {
        constexpr unsigned int padding = ::rocprim::device_warp_size();

        prefix_type prefix;

        // Exponential-ish backoff cap for the optional sleep path.
        const unsigned int SLEEP_MAX = 32;
        unsigned int times_through = 1;

        // atomic_add(..., 0) is used to load the packed word atomically.
        prefix_underlying_type p = ::rocprim::detail::atomic_add(&prefixes[padding + block_id], 0);
#ifndef __HIP_CPU_RT__
        __builtin_memcpy(&prefix, &p, sizeof(prefix_type));
#else
        std::memcpy(&prefix, &p, sizeof(prefix_type));
#endif
        while(prefix.flag == PREFIX_EMPTY)
        {
            if(UseSleep)
            {
                // Sleep progressively longer (up to SLEEP_MAX iterations) to
                // reduce memory traffic while waiting.
                for(unsigned int j = 0; j < times_through; j++)
#ifndef __HIP_CPU_RT__
                    __builtin_amdgcn_s_sleep(1);
#else
                    std::this_thread::sleep_for(std::chrono::microseconds{1});
#endif
                if(times_through < SLEEP_MAX)
                    times_through++;
            }
            // atomic_add(..., 0) is used to load values atomically
            prefix_underlying_type p = ::rocprim::detail::atomic_add(&prefixes[padding + block_id], 0);
#ifndef __HIP_CPU_RT__
            __builtin_memcpy(&prefix, &p, sizeof(prefix_type));
#else
            std::memcpy(&prefix, &p, sizeof(prefix_type));
#endif
        }
        // return
        flag = prefix.flag;
        value = prefix.value;
    }

private:
    // Packs {flag, value} and publishes it with a single atomic exchange.
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void set(const unsigned int block_id, const flag_type flag, const T value)
    {
        constexpr unsigned int padding = ::rocprim::device_warp_size();

        prefix_type prefix = { flag, value };
        prefix_underlying_type p;
#ifndef __HIP_CPU_RT__
        __builtin_memcpy(&p, &prefix, sizeof(prefix_type));
#else
        std::memcpy(&p, &prefix, sizeof(prefix_type));
#endif
        ::rocprim::detail::atomic_exch(&prefixes[padding + block_id], p);
    }

    prefix_underlying_type * prefixes;
};
// Flag, partial and final prefixes are stored in separate arrays.
// Consistency ensured by memory fences between flag and prefixes load/store operations.
template
<
class
T
,
bool
UseSleep
>
struct
lookback_scan_state
<
T
,
UseSleep
,
false
>
{
public:
using
flag_type
=
char
;
using
value_type
=
T
;
// temp_storage must point to allocation of get_storage_size(number_of_blocks) bytes
ROCPRIM_HOST
static
inline
lookback_scan_state
create
(
void
*
temp_storage
,
const
unsigned
int
number_of_blocks
)
{
const
auto
n
=
::
rocprim
::
host_warp_size
()
+
number_of_blocks
;
lookback_scan_state
state
;
auto
ptr
=
static_cast
<
char
*>
(
temp_storage
);
state
.
prefixes_flags
=
reinterpret_cast
<
flag_type
*>
(
ptr
);
ptr
+=
::
rocprim
::
detail
::
align_size
(
n
*
sizeof
(
flag_type
));
state
.
prefixes_partial_values
=
reinterpret_cast
<
T
*>
(
ptr
);
ptr
+=
::
rocprim
::
detail
::
align_size
(
n
*
sizeof
(
T
));
state
.
prefixes_complete_values
=
reinterpret_cast
<
T
*>
(
ptr
);
return
state
;
}
ROCPRIM_HOST
static
inline
size_t
get_storage_size
(
const
unsigned
int
number_of_blocks
)
{
const
auto
n
=
::
rocprim
::
host_warp_size
()
+
number_of_blocks
;
size_t
size
=
::
rocprim
::
detail
::
align_size
(
n
*
sizeof
(
flag_type
));
size
+=
2
*
::
rocprim
::
detail
::
align_size
(
n
*
sizeof
(
T
));
return
size
;
}
ROCPRIM_DEVICE
ROCPRIM_INLINE
void
initialize_prefix
(
const
unsigned
int
block_id
,
const
unsigned
int
number_of_blocks
)
{
constexpr
unsigned
int
padding
=
::
rocprim
::
device_warp_size
();
if
(
block_id
<
number_of_blocks
)
{
prefixes_flags
[
padding
+
block_id
]
=
PREFIX_EMPTY
;
}
if
(
block_id
<
padding
)
{
prefixes_flags
[
block_id
]
=
PREFIX_INVALID
;
}
}
ROCPRIM_DEVICE
ROCPRIM_INLINE
void
set_partial
(
const
unsigned
int
block_id
,
const
T
value
)
{
constexpr
unsigned
int
padding
=
::
rocprim
::
device_warp_size
();
store_volatile
(
&
prefixes_partial_values
[
padding
+
block_id
],
value
);
::
rocprim
::
detail
::
memory_fence_device
();
store_volatile
<
flag_type
>
(
&
prefixes_flags
[
padding
+
block_id
],
PREFIX_PARTIAL
);
}
ROCPRIM_DEVICE
ROCPRIM_INLINE
void
set_complete
(
const
unsigned
int
block_id
,
const
T
value
)
{
constexpr
unsigned
int
padding
=
::
rocprim
::
device_warp_size
();
store_volatile
(
&
prefixes_complete_values
[
padding
+
block_id
],
value
);
::
rocprim
::
detail
::
memory_fence_device
();
store_volatile
<
flag_type
>
(
&
prefixes_flags
[
padding
+
block_id
],
PREFIX_COMPLETE
);
}
// Spin-waits until the entry for block_id is published (flag != PREFIX_EMPTY),
// then returns its flag and the corresponding value (partial or complete).
// block_id must be > 0
ROCPRIM_DEVICE ROCPRIM_INLINE
void get(const unsigned int block_id, flag_type& flag, T& value)
{
    constexpr unsigned int padding = ::rocprim::device_warp_size();

    // Cap for the exponential back-off below.
    const unsigned int SLEEP_MAX = 32;
    unsigned int times_through = 1;

    flag = load_volatile(&prefixes_flags[padding + block_id]);
    // Fence after the flag load so the subsequent value load is not
    // reordered before it.
    ::rocprim::detail::memory_fence_device();
    while(flag == PREFIX_EMPTY)
    {
        if(UseSleep)
        {
            // Back off for `times_through` sleep units, growing each
            // iteration up to SLEEP_MAX to reduce memory-system pressure
            // while the producer block is still working.
            for(unsigned int j = 0; j < times_through; j++)
#ifndef __HIP_CPU_RT__
                __builtin_amdgcn_s_sleep(1);
#else
                std::this_thread::sleep_for(std::chrono::microseconds{1});
#endif
            if(times_through < SLEEP_MAX)
                times_through++;
        }
        flag = load_volatile(&prefixes_flags[padding + block_id]);
        ::rocprim::detail::memory_fence_device();
    }
    // The producer wrote the value before the flag (with a fence between),
    // so the matching value is visible here.
    if(flag == PREFIX_PARTIAL)
        value = load_volatile(&prefixes_partial_values[padding + block_id]);
    else
        value = load_volatile(&prefixes_complete_values[padding + block_id]);
}
private:
    // Per-block lookback flags (PREFIX_EMPTY/PARTIAL/COMPLETE), preceded by
    // device_warp_size() PREFIX_INVALID padding entries.
    flag_type* prefixes_flags;
    // We need to separate arrays for partial and final prefixes, because
    // value can be overwritten before flag is changed (flag and value are
    // not stored in single instruction).
    T* prefixes_partial_values;
    T* prefixes_complete_values;
};
// Decoupled-lookback prefix functor: given this block's reduction, walks the
// scan state of preceding blocks (a warp at a time) to produce the exclusive
// prefix for this block, and publishes this block's partial and complete
// prefixes for its successors. Intended to be invoked by a single warp.
template<class T, class BinaryFunction, class LookbackScanState>
class lookback_scan_prefix_op
{
    using flag_type = typename LookbackScanState::flag_type;

    static_assert(
        std::is_same<T, typename LookbackScanState::value_type>::value,
        "T must be LookbackScanState::value_type"
    );

public:
    // \param block_id   - index of the block this functor publishes for
    // \param scan_op    - binary scan operator
    // \param scan_state - shared lookback state (stored by reference)
    ROCPRIM_DEVICE ROCPRIM_INLINE
    lookback_scan_prefix_op(unsigned int block_id,
                            BinaryFunction scan_op,
                            LookbackScanState &scan_state)
        : block_id_(block_id),
          scan_op_(scan_op),
          scan_state_(scan_state)
    {
    }

    ROCPRIM_DEVICE ROCPRIM_INLINE
    ~lookback_scan_prefix_op() = default;

    // Loads one predecessor entry per lane (lane 0 reads `block_id`, lane 1
    // reads `block_id - 1`, ...) and reduces them into `partial_prefix`,
    // stopping the segment at the most recent PREFIX_COMPLETE entry.
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void reduce_partial_prefixes(unsigned int block_id,
                                 flag_type& flag,
                                 T& partial_prefix)
    {
        // Order of reduction must be reversed, because 0th thread has
        // prefix from the (block_id_ - 1) block, 1st thread has prefix
        // from (block_id_ - 2) block etc.
        using headflag_scan_op_type = reverse_binary_op_wrapper<BinaryFunction, T, T>;
        using warp_reduce_prefix_type =
            warp_reduce_crosslane<T, ::rocprim::device_warp_size(), false>;

        // Blocks until the predecessor publishes at least a partial prefix.
        T block_prefix;
        scan_state_.get(block_id, flag, block_prefix);

        auto headflag_scan_op = headflag_scan_op_type(scan_op_);
        warp_reduce_prefix_type().tail_segmented_reduce(
            block_prefix,
            partial_prefix,
            (flag == PREFIX_COMPLETE),
            headflag_scan_op
        );
    }

    // Computes the exclusive prefix of this block by walking predecessors
    // backwards, one warp-sized window per iteration.
    ROCPRIM_DEVICE ROCPRIM_INLINE
    T get_prefix()
    {
        flag_type flag;
        T partial_prefix;
        unsigned int previous_block_id = block_id_ - ::rocprim::lane_id() - 1;

        // reduce last warp_size() number of prefixes to
        // get the complete prefix for this block.
        reduce_partial_prefixes(previous_block_id, flag, partial_prefix);
        T prefix = partial_prefix;

        // while we don't load a complete prefix, reduce partial prefixes
        while(::rocprim::detail::warp_all(flag != PREFIX_COMPLETE))
        {
            previous_block_id -= ::rocprim::device_warp_size();
            reduce_partial_prefixes(previous_block_id, flag, partial_prefix);
            // Earlier window's result is the left operand.
            prefix = scan_op_(partial_prefix, prefix);
        }
        return prefix;
    }

    // Called once with this block's reduction; returns the exclusive prefix.
    ROCPRIM_DEVICE ROCPRIM_INLINE
    T operator()(T reduction)
    {
        // Set partial prefix for next block
        if(::rocprim::lane_id() == 0)
        {
            scan_state_.set_partial(block_id_, reduction);
        }

        // Get prefix
        auto prefix = get_prefix();

        // Set complete prefix for next block
        if(::rocprim::lane_id() == 0)
        {
            scan_state_.set_complete(block_id_, scan_op_(prefix, reduction));
        }
        return prefix;
    }

protected:
    unsigned int block_id_;
    BinaryFunction scan_op_;
    LookbackScanState& scan_state_;
};
// Determines whether the "sleep" (back-off) variant of the lookback scan
// state should be used on the current device. In this CUDA port no hardware
// requires the workaround, so `use_sleep` is always set to false; the device
// queries are kept so callers still receive any runtime error (e.g. no
// device present).
//
// \param use_sleep - set to false on success
// \return cudaSuccess, or the error from cudaGetDevice /
//         cudaGetDeviceProperties
inline cudaError_t is_sleep_scan_state_used(bool& use_sleep)
{
    cudaDeviceProp prop;
    int deviceId;
    if(const cudaError_t error = cudaGetDevice(&deviceId))
    {
        return error;
    }
    else if(const cudaError_t error = cudaGetDeviceProperties(&prop, deviceId))
    {
        return error;
    }
    (void)prop; // queried only for its error code in this port
    use_sleep = false;
    return cudaSuccess;
}
}
// end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_LOOKBACK_SCAN_STATE_HPP_
3rdparty/cub/rocprim/device/detail/ordered_block_id.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_ORDERED_BLOCK_ID_HPP_
#define ROCPRIM_DEVICE_DETAIL_ORDERED_BLOCK_ID_HPP_
#include <type_traits>
#include <limits>
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../types.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace
detail
{
// Helper struct for generating ordered unique ids for blocks in a grid.
// Hands out strictly increasing ids to blocks in the order in which they
// actually start executing (not their launch index), via a global atomic
// counter.
template<class T /* id type */ = unsigned int>
struct ordered_block_id
{
    static_assert(std::is_integral<T>::value, "T must be integer");
    using id_type = T;

    // shared memory temporary storage type
    struct storage_type
    {
        id_type id;
    };

    // Wraps a device pointer to the global counter; does not take ownership.
    ROCPRIM_HOST
    static inline ordered_block_id create(id_type* id)
    {
        ordered_block_id ordered_id;
        ordered_id.id = id;
        return ordered_id;
    }

    // Bytes of device memory needed for the global counter.
    ROCPRIM_HOST
    static inline size_t get_storage_size()
    {
        return sizeof(id_type);
    }

    // Resets the counter; must not race with concurrent get() calls.
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void reset()
    {
        *id = static_cast<id_type>(0);
    }

    // Returns this block's ordered id. Thread 0 atomically claims the next
    // id into shared storage; the barrier makes it visible block-wide.
    ROCPRIM_DEVICE ROCPRIM_INLINE
    id_type get(unsigned int tid, storage_type& storage)
    {
        if(tid == 0)
        {
            storage.id = ::rocprim::detail::atomic_add(this->id, 1);
        }
        ::rocprim::syncthreads();
        return storage.id;
    }

    // Device pointer to the global counter (not owned).
    id_type* id;
};
}
// end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_ORDERED_BLOCK_ID_HPP_
3rdparty/cub/rocprim/device/detail/uint_fast_div.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_UINT_FAST_DIV_HPP_
#define ROCPRIM_DEVICE_DETAIL_UINT_FAST_DIV_HPP_
#include "../../config.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace
detail
{
// Precomputed magic-number parameters for fast unsigned 32-bit division by a
// runtime-constant divisor (multiply-high + shift instead of a hardware
// divide). The constructor follows the `magicu` algorithm from Hacker's
// Delight (2nd ed., ch. 10); see the companion operator/ below for how the
// fields are consumed.
struct uint_fast_div
{
    unsigned int magic; // Magic number
    unsigned int shift; // shift amount
    unsigned int add; // "add" indicator

    ROCPRIM_HOST_DEVICE inline
    uint_fast_div() = default;

    ROCPRIM_HOST_DEVICE inline
    uint_fast_div(unsigned int d)
    {
        // Must have 1 <= d <= 2**32-1.
        if(d == 1)
        {
            // magic == 0 is the sentinel for division by 1 (see operator/).
            magic = 0;
            shift = 0;
            add = 0;
            return;
        }

        int p;
        unsigned int p32 = 1, q, r, delta;
        add = 0; // Initialize "add" indicator.
        p = 31; // Initialize p.
        q = 0x7FFFFFFF / d; // Initialize q = (2**p - 1)/d.
        r = 0x7FFFFFFF - q * d; // Init. r = rem(2**p - 1, d).
        // Search for the smallest p in [32, 64) for which the multiply-shift
        // approximation is exact for all 32-bit dividends, doubling q and r
        // each step and tracking overflow of q in `add`.
        do
        {
            p = p + 1;
            if(p == 32)
                p32 = 1; // Set p32 = 2**(p-32).
            else
                p32 = 2 * p32;
            if(r + 1 >= d - r)
            {
                if(q >= 0x7FFFFFFF)
                    add = 1;
                q = 2 * q + 1;
                r = 2 * r + 1 - d;
            }
            else
            {
                if(q >= 0x80000000)
                    add = 1;
                q = 2 * q;
                r = 2 * r + 1;
            }
            delta = d - 1 - r;
        } while(p < 64 && p32 < delta);
        magic = q + 1; // Magic number and
        shift = p - 32; // shift amount
        // When the magic number overflowed (add == 1), operator/ performs an
        // extra correction step that consumes one bit of the shift.
        if(add)
            shift--;
    }
};
// Fast unsigned division by a precomputed uint_fast_div:
// n / d == high32(magic * n) [with an add-correction] >> shift.
ROCPRIM_HOST_DEVICE inline
unsigned int operator/(unsigned int n, const uint_fast_div& divisor)
{
    // magic == 0 is the sentinel the constructor stores for d == 1.
    if(divisor.magic == 0)
    {
        return n;
    }

    // Higher 32-bit of 64-bit multiplication
    const unsigned long long product = static_cast<unsigned long long>(divisor.magic)
                                       * static_cast<unsigned long long>(n);
    unsigned int q = static_cast<unsigned int>(product >> 32);

    if(divisor.add)
    {
        // Correction step for divisors whose magic number overflowed 32 bits.
        q += (n - q) >> 1;
    }
    return q >> divisor.shift;
}
}
// end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_UINT_FAST_DIV_HPP_
3rdparty/cub/rocprim/device/device_adjacent_difference.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_
#define ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_
#include "detail/device_adjacent_difference.hpp"
#include "device_adjacent_difference_config.hpp"
#include "config_types.hpp"
#include "device_transform.hpp"
#include "../config.hpp"
#include "../functional.hpp"
#include "../detail/various.hpp"
#include "../iterator/counting_iterator.hpp"
#include "../iterator/transform_iterator.hpp"
#include <cuda_runtime.h>
#include <chrono>
#include <iostream>
#include <iterator>
#include <cstddef>
/// \file
///
/// Device level adjacent_difference parallel primitives
BEGIN_ROCPRIM_NAMESPACE
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
// Checks cudaGetLastError() after a kernel launch and returns it on failure.
// When the enclosing function's `debug_synchronous` is set, also synchronizes
// `stream`, prints the kernel name, `size`, and the elapsed time since
// `start`, and returns any synchronization error. Relies on
// `debug_synchronous` and `stream` being in scope at the expansion site.
#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \
    { \
        auto _error = cudaGetLastError(); \
        if(_error != cudaSuccess) \
            return _error; \
        if(debug_synchronous) \
        { \
            std::cout << name << "(" << size << ")"; \
            auto __error = cudaStreamSynchronize(stream); \
            if(__error != cudaSuccess) \
                return __error; \
            auto _end = std::chrono::high_resolution_clock::now(); \
            auto _d = std::chrono::duration_cast<std::chrono::duration<double>>(_end - start); \
            std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \
        } \
    }
namespace
detail
{
// Thin kernel wrapper: forwards all arguments to the block-level
// implementation. `previous_values` holds saved block-boundary items for the
// in-place case; `starting_block` offsets block ids when the input is
// processed in multiple launches.
template<typename Config,
         bool InPlace,
         bool Right,
         typename InputIt,
         typename OutputIt,
         typename BinaryFunction>
void ROCPRIM_KERNEL __launch_bounds__(Config::block_size)
adjacent_difference_kernel(const InputIt input,
                           const OutputIt output,
                           const std::size_t size,
                           const BinaryFunction op,
                           const typename std::iterator_traits<InputIt>::value_type* previous_values,
                           const std::size_t starting_block)
{
    adjacent_difference_kernel_impl<Config, InPlace, Right>(
        input, output, size, op, previous_values, starting_block);
}
// Host-side driver shared by all public adjacent_difference entry points.
// First call (temporary_storage == nullptr) only reports the required
// temporary-storage size; the second call runs the computation, splitting
// inputs larger than config::size_limit across multiple kernel launches.
template<typename Config,
         bool InPlace,
         bool Right,
         typename InputIt,
         typename OutputIt,
         typename BinaryFunction>
cudaError_t adjacent_difference_impl(void* const temporary_storage,
                                     std::size_t& storage_size,
                                     const InputIt input,
                                     const OutputIt output,
                                     const std::size_t size,
                                     const BinaryFunction op,
                                     const cudaStream_t stream,
                                     const bool debug_synchronous)
{
    using value_type = typename std::iterator_traits<InputIt>::value_type;
    using config = detail::default_or_custom_config<
        Config,
        detail::default_adjacent_difference_config<ROCPRIM_TARGET_ARCH, value_type>>;

    static constexpr unsigned int block_size = config::block_size;
    static constexpr unsigned int items_per_thread = config::items_per_thread;
    static constexpr unsigned int items_per_block = block_size * items_per_thread;

    const std::size_t num_blocks = ceiling_div(size, items_per_block);

    if(temporary_storage == nullptr)
    {
        if(InPlace && num_blocks >= 2)
        {
            // One saved boundary item per block boundary.
            storage_size = align_size((num_blocks - 1) * sizeof(value_type));
        }
        else
        {
            // Make sure user won't try to allocate 0 bytes memory, because
            // cudaMalloc will return nullptr when size is zero.
            storage_size = 4;
        }
        return cudaSuccess;
    }

    if(num_blocks == 0)
    {
        return cudaSuccess;
    }

    // Copy values before they are overwritten to use as tile predecessors/successors
    // this is not dereferenced when the operation is not in place
    auto* const previous_values = static_cast<value_type*>(temporary_storage);
    if ROCPRIM_IF_CONSTEXPR(InPlace)
    {
        // If doing left adjacent diff then the last item of each block is needed for the
        // next block, otherwise the first item is needed for the previous block
        static constexpr auto offset = items_per_block - (Right ? 0 : 1);
        const auto block_starts_iter = make_transform_iterator(
            rocprim::make_counting_iterator(std::size_t{0}),
            [base = input + offset](std::size_t i) { return base[i * items_per_block]; });

        // Gather the boundary item of each block into temporary storage.
        const cudaError_t error = ::rocprim::transform(block_starts_iter,
                                                       previous_values,
                                                       num_blocks - 1,
                                                       rocprim::identity<>{},
                                                       stream,
                                                       debug_synchronous);
        if(error != cudaSuccess)
        {
            return error;
        }
    }

    static constexpr unsigned int size_limit = config::size_limit;
    static constexpr auto number_of_blocks_limit = std::max(size_limit / items_per_block, 1u);
    static constexpr auto aligned_size_limit = number_of_blocks_limit * items_per_block;

    // Launch number_of_blocks_limit blocks while there is still at least as many blocks
    // left as the limit
    const auto number_of_launch = ceiling_div(size, aligned_size_limit);
    if(debug_synchronous)
    {
        std::cout << "----------------------------------\n";
        std::cout << "size: " << size << '\n';
        std::cout << "aligned_size_limit: " << aligned_size_limit << '\n';
        std::cout << "number_of_launch: " << number_of_launch << '\n';
        std::cout << "block_size: " << block_size << '\n';
        std::cout << "items_per_block: " << items_per_block << '\n';
        std::cout << "----------------------------------\n";
    }

    for(std::size_t i = 0, offset = 0; i < number_of_launch;
        ++i, offset += aligned_size_limit)
    {
        const auto current_size = static_cast<unsigned int>(
            std::min<std::size_t>(size - offset, aligned_size_limit));
        const auto current_blocks = ceiling_div(current_size, items_per_block);
        const auto starting_block = i * number_of_blocks_limit;

        std::chrono::time_point<std::chrono::high_resolution_clock> start;
        if(debug_synchronous)
        {
            std::cout << "index: " << i << '\n';
            std::cout << "current_size: " << current_size << '\n';
            std::cout << "number of blocks: " << current_blocks << '\n';
            start = std::chrono::high_resolution_clock::now();
        }

        // NOTE(review): the kernel receives the total `size` (not
        // current_size) together with starting_block — presumably so the
        // implementation can identify the globally first/last block; confirm
        // against adjacent_difference_kernel_impl.
        adjacent_difference_kernel<config, InPlace, Right>
            <<<dim3(current_blocks), dim3(block_size), 0, stream>>>(
                input + offset,
                output + offset,
                size,
                op,
                previous_values + starting_block,
                starting_block);
        ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("adjacent_difference_kernel",
                                                    current_size,
                                                    start);
    }

    return cudaSuccess;
}
}
// namespace detail
#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR
#endif // DOXYGEN_SHOULD_SKIP_THIS
/// \addtogroup devicemodule
/// @{
/// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements
/// in device accessible memory. Writes the output to the position of the left item.
///
/// Copies the first item to the output, then calls the supplied operator with each pair
/// of neighboring elements and writes its result to the location of the second element.
/// Equivalent to the following code
/// \code{.cpp}
/// output[0] = input[0];
/// for(std::size_t i = 1; i < size; ++i)
/// {
///     output[i] = op(input[i], input[i - 1]);
/// }
/// \endcode
///
/// \tparam Config - [optional] configuration of the primitive. It can be
/// `adjacent_difference_config` or a class with the same members.
/// \tparam InputIt - [inferred] random-access iterator type of the input range. Must meet the
/// requirements of a C++ InputIterator concept. It can be a simple pointer type.
/// \tparam OutputIt - [inferred] random-access iterator type of the output range. Must meet the
/// requirements of a C++ OutputIterator concept. It can be a simple pointer type.
/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to
/// consecutive items. The signature of the function should be equivalent to the following:
/// `U f(const T1& a, const T2& b)`. The signature does not need to have
/// `const &`, but function object must not modify the object passed to it
/// \param temporary_storage - pointer to a device-accessible temporary storage. When
/// a null pointer is passed, the required allocation size (in bytes) is written to
/// `storage_size` and function returns without performing the scan operation
/// \param storage_size - reference to a size (in bytes) of `temporary_storage`
/// \param input - iterator to the input range
/// \param output - iterator to the output range, must not overlap with input
/// \param size - number of items in the input
/// \param op - [optional] the binary operation to apply
/// \param stream - [optional] HIP stream object. Default is `0` (the default stream)
/// \param debug_synchronous - [optional] If true, synchronization after every kernel
/// launch is forced in order to check for errors and extra debugging info is printed to the
/// standard output. Default value is `false`
///
/// \return `cudaSuccess` (0) after successful scan, otherwise the HIP runtime error of
/// type `cudaError_t`
///
/// \par Example
/// \parblock
/// In this example a device-level adjacent_difference operation is performed on integer values.
///
/// \code{.cpp}
/// #include <rocprim/rocprim.hpp> //or <rocprim/device/device_adjacent_difference.hpp>
///
/// // custom binary function
/// auto binary_op =
/// [] __device__ (int a, int b) -> int
/// {
/// return a - b;
/// };
///
/// // Prepare input and output (declare pointers, allocate device memory etc.)
/// std::size_t size; // e.g., 8
/// int* input; // e.g., [8, 7, 6, 5, 4, 3, 2, 1]
/// int* output; // empty array of 8 elements
///
/// std::size_t temporary_storage_size_bytes;
/// void* temporary_storage_ptr = nullptr;
/// // Get required size of the temporary storage
/// rocprim::adjacent_difference(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// input, output, size, binary_op
/// );
///
/// // allocate temporary storage
/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes);
///
/// // perform adjacent difference
/// rocprim::adjacent_difference(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// input, output, size, binary_op
/// );
/// // output: [8, 1, 1, 1, 1, 1, 1, 1]
/// \endcode
/// \endparblock
template<typename Config = default_config,
         typename InputIt,
         typename OutputIt,
         typename BinaryFunction = ::rocprim::minus<>>
cudaError_t adjacent_difference(void* const temporary_storage,
                                std::size_t& storage_size,
                                const InputIt input,
                                const OutputIt output,
                                const std::size_t size,
                                const BinaryFunction op = BinaryFunction{},
                                const cudaStream_t stream = 0,
                                const bool debug_synchronous = false)
{
    // Left-oriented, out-of-place variant:
    //   output[0] = input[0]; output[i] = op(input[i], input[i - 1]).
    return detail::adjacent_difference_impl<Config,
                                            /*InPlace=*/false,
                                            /*Right=*/false>(temporary_storage,
                                                             storage_size,
                                                             input,
                                                             output,
                                                             size,
                                                             op,
                                                             stream,
                                                             debug_synchronous);
}
/// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements
/// in device accessible memory. Writes the output to the position of the left item in place.
///
/// Copies the first item to the output, then calls the supplied operator with each pair
/// of neighboring elements and writes its result to the location of the second element.
/// Equivalent to the following code
/// \code{.cpp}
/// for(std::size_t i = size - 1; i > 0; --i)
/// {
///     input[i] = op(input[i], input[i - 1]);
/// }
/// \endcode
///
/// \tparam Config - [optional] configuration of the primitive. It can be
/// `adjacent_difference_config` or a class with the same members.
/// \tparam InputIt - [inferred] random-access iterator type of the value range. Must meet the
/// requirements of a C++ InputIterator concept. It can be a simple pointer type.
/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to
/// consecutive items. The signature of the function should be equivalent to the following:
/// `U f(const T1& a, const T2& b)`. The signature does not need to have
/// `const &`, but function object must not modify the object passed to it
/// \param temporary_storage - pointer to a device-accessible temporary storage. When
/// a null pointer is passed, the required allocation size (in bytes) is written to
/// `storage_size` and function returns without performing the scan operation
/// \param storage_size - reference to a size (in bytes) of `temporary_storage`
/// \param values - iterator to the range values, will be overwritten with the results
/// \param size - number of items in the input
/// \param op - [optional] the binary operation to apply
/// \param stream - [optional] HIP stream object. Default is `0` (the default stream)
/// \param debug_synchronous - [optional] If true, synchronization after every kernel
/// launch is forced in order to check for errors and extra debugging info is printed to the
/// standard output. Default value is `false`
///
/// \return `cudaSuccess` (0) after successful scan, otherwise the HIP runtime error of
/// type `cudaError_t`
template<typename Config = default_config,
         typename InputIt,
         typename BinaryFunction = ::rocprim::minus<>>
cudaError_t adjacent_difference_inplace(void* const temporary_storage,
                                        std::size_t& storage_size,
                                        const InputIt values,
                                        const std::size_t size,
                                        const BinaryFunction op = BinaryFunction{},
                                        const cudaStream_t stream = 0,
                                        const bool debug_synchronous = false)
{
    // Left-oriented, in-place variant: `values` serves as both input and
    // output; the impl saves block-boundary items into temporary storage.
    return detail::adjacent_difference_impl<Config,
                                            /*InPlace=*/true,
                                            /*Right=*/false>(temporary_storage,
                                                             storage_size,
                                                             values,
                                                             values,
                                                             size,
                                                             op,
                                                             stream,
                                                             debug_synchronous);
}
/// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements
/// in device accessible memory. Writes the output to the position of the right item.
///
/// Copies the last item to the output, then calls the supplied operator with each pair
/// of neighboring elements and writes its result to the location of the first element.
/// Equivalent to the following code
/// \code{.cpp}
/// output[size - 1] = input[size - 1];
/// for(std::size_t i = 0; i < size - 1; ++i)
/// {
///     output[i] = op(input[i], input[i + 1]);
/// }
/// \endcode
///
/// \tparam Config - [optional] configuration of the primitive. It can be
/// `adjacent_difference_config` or a class with the same members.
/// \tparam InputIt - [inferred] random-access iterator type of the input range. Must meet the
/// requirements of a C++ InputIterator concept. It can be a simple pointer type.
/// \tparam OutputIt - [inferred] random-access iterator type of the output range. Must meet the
/// requirements of a C++ OutputIterator concept. It can be a simple pointer type.
/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to
/// consecutive items. The signature of the function should be equivalent to the following:
/// `U f(const T1& a, const T2& b)`. The signature does not need to have
/// `const &`, but function object must not modify the object passed to it
/// \param temporary_storage - pointer to a device-accessible temporary storage. When
/// a null pointer is passed, the required allocation size (in bytes) is written to
/// `storage_size` and function returns without performing the scan operation
/// \param storage_size - reference to a size (in bytes) of `temporary_storage`
/// \param input - iterator to the input range
/// \param output - iterator to the output range, must not overlap with input
/// \param size - number of items in the input
/// \param op - [optional] the binary operation to apply
/// \param stream - [optional] HIP stream object. Default is `0` (the default stream)
/// \param debug_synchronous - [optional] If true, synchronization after every kernel
/// launch is forced in order to check for errors and extra debugging info is printed to the
/// standard output. Default value is `false`
///
/// \return `cudaSuccess` (0) after successful scan, otherwise the HIP runtime error of
/// type `cudaError_t`
///
/// \par Example
/// \parblock
/// In this example a device-level adjacent_difference operation is performed on integer values.
///
/// \code{.cpp}
/// #include <rocprim/rocprim.hpp> //or <rocprim/device/device_adjacent_difference.hpp>
///
/// // custom binary function
/// auto binary_op =
/// [] __device__ (int a, int b) -> int
/// {
/// return a - b;
/// };
///
/// // Prepare input and output (declare pointers, allocate device memory etc.)
/// std::size_t size; // e.g., 8
/// int* input; // e.g., [1, 2, 3, 4, 5, 6, 7, 8]
/// int* output; // empty array of 8 elements
///
/// std::size_t temporary_storage_size_bytes;
/// void* temporary_storage_ptr = nullptr;
/// // Get required size of the temporary storage
/// rocprim::adjacent_difference_right(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// input, output, size, binary_op
/// );
///
/// // allocate temporary storage
/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes);
///
/// // perform adjacent difference
/// rocprim::adjacent_difference_right(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// input, output, size, binary_op
/// );
/// // output: [1, 1, 1, 1, 1, 1, 1, 8]
/// \endcode
/// \endparblock
template<typename Config = default_config,
         typename InputIt,
         typename OutputIt,
         typename BinaryFunction = ::rocprim::minus<>>
cudaError_t adjacent_difference_right(void* const temporary_storage,
                                      std::size_t& storage_size,
                                      const InputIt input,
                                      const OutputIt output,
                                      const std::size_t size,
                                      const BinaryFunction op = BinaryFunction{},
                                      const cudaStream_t stream = 0,
                                      const bool debug_synchronous = false)
{
    // Right-oriented, out-of-place variant:
    //   output[size - 1] = input[size - 1];
    //   output[i] = op(input[i], input[i + 1]).
    return detail::adjacent_difference_impl<Config,
                                            /*InPlace=*/false,
                                            /*Right=*/true>(temporary_storage,
                                                            storage_size,
                                                            input,
                                                            output,
                                                            size,
                                                            op,
                                                            stream,
                                                            debug_synchronous);
}
/// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements
/// in device accessible memory. Writes the output to the position of the right item in place.
///
/// Copies the last item to the output, then calls the supplied operator with each pair
/// of neighboring elements and writes its result to the location of the first element.
/// Equivalent to the following code
/// \code{.cpp}
/// for(std::size_t i = 0; i < size - 1; ++i)
/// {
///     input[i] = op(input[i], input[i + 1]);
/// }
/// \endcode
///
/// \tparam Config - [optional] configuration of the primitive. It can be
/// `adjacent_difference_config` or a class with the same members.
/// \tparam InputIt - [inferred] random-access iterator type of the value range. Must meet the
/// requirements of a C++ InputIterator concept. It can be a simple pointer type.
/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to
/// consecutive items. The signature of the function should be equivalent to the following:
/// `U f(const T1& a, const T2& b)`. The signature does not need to have
/// `const &`, but function object must not modify the object passed to it
/// \param temporary_storage - pointer to a device-accessible temporary storage. When
/// a null pointer is passed, the required allocation size (in bytes) is written to
/// `storage_size` and function returns without performing the scan operation
/// \param storage_size - reference to a size (in bytes) of `temporary_storage`
/// \param values - iterator to the range values, will be overwritten with the results
/// \param size - number of items in the input
/// \param op - [optional] the binary operation to apply
/// \param stream - [optional] HIP stream object. Default is `0` (the default stream)
/// \param debug_synchronous - [optional] If true, synchronization after every kernel
/// launch is forced in order to check for errors and extra debugging info is printed to the
/// standard output. Default value is `false`
///
/// \return `cudaSuccess` (0) after successful scan, otherwise the HIP runtime error of
/// type `cudaError_t`
template<typename Config = default_config,
         typename InputIt,
         typename BinaryFunction = ::rocprim::minus<>>
cudaError_t adjacent_difference_right_inplace(void* const temporary_storage,
                                              std::size_t& storage_size,
                                              const InputIt values,
                                              const std::size_t size,
                                              const BinaryFunction op = BinaryFunction{},
                                              const cudaStream_t stream = 0,
                                              const bool debug_synchronous = false)
{
    // Right-oriented, in-place variant: `values` serves as both input and
    // output; the impl saves block-boundary items into temporary storage.
    return detail::adjacent_difference_impl<Config,
                                            /*InPlace=*/true,
                                            /*Right=*/true>(temporary_storage,
                                                            storage_size,
                                                            values,
                                                            values,
                                                            size,
                                                            op,
                                                            stream,
                                                            debug_synchronous);
}
/// @}
// end of group devicemodule
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_
3rdparty/cub/rocprim/device/device_adjacent_difference_config.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_CONFIG_HPP_
#define ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_CONFIG_HPP_
#include <type_traits>
#include "../config.hpp"
#include "../detail/various.hpp"
#include "../functional.hpp"
#include "config_types.hpp"
#include "../block/block_load.hpp"
#include "../block/block_store.hpp"
/// \addtogroup primitivesmodule_deviceconfigs
/// @{
BEGIN_ROCPRIM_NAMESPACE
/// \brief Configuration of device-level adjacent_difference primitives.
///
/// \tparam BlockSize - number of threads in a block.
/// \tparam ItemsPerThread - number of items processed by each thread
/// \tparam LoadMethod - method for loading input values
/// \tparam StoreMethod - method for storing values
/// \tparam SizeLimit - limit on the number of items for a single adjacent_difference kernel launch.
/// Larger input sizes will be broken up to multiple kernel launches.
/// \brief Configuration of device-level adjacent_difference primitives.
///
/// Extends kernel_config with block load/store method selections.
///
/// \tparam BlockSize - number of threads in a block.
/// \tparam ItemsPerThread - number of items processed by each thread.
/// \tparam LoadMethod - method for loading input values.
/// \tparam StoreMethod - method for storing values.
/// \tparam SizeLimit - limit on the number of items for a single kernel launch;
/// larger inputs are broken up into multiple launches.
template<unsigned int       BlockSize,
         unsigned int       ItemsPerThread,
         block_load_method  LoadMethod  = block_load_method::block_load_transpose,
         block_store_method StoreMethod = block_store_method::block_store_transpose,
         unsigned int       SizeLimit   = ROCPRIM_GRID_SIZE_LIMIT>
struct adjacent_difference_config : kernel_config<BlockSize, ItemsPerThread, SizeLimit>
{
    // Expose the chosen methods as members so kernels can read them.
    static constexpr block_load_method  load_method  = LoadMethod;
    static constexpr block_store_method store_method = StoreMethod;
};
namespace detail
{

/// Fallback configuration used when no architecture-specific tuning exists.
/// Scales items-per-thread down for large value types so that each thread
/// handles roughly 16 ints' worth of data (but never fewer than 1 item).
template<class Value>
struct adjacent_difference_config_fallback
{
    // How many ints one Value occupies, rounded up.
    static constexpr unsigned int item_scale
        = ::rocprim::detail::ceiling_div<unsigned int>(sizeof(Value), sizeof(int));

    using type = adjacent_difference_config<256, ::rocprim::max(1u, 16u / item_scale)>;
};

/// Default configuration selector: dispatches on the target architecture,
/// currently resolving to the fallback for every arch.
template<unsigned int TargetArch, class Value>
struct default_adjacent_difference_config
    : select_arch<TargetArch, adjacent_difference_config_fallback<Value>>
{};

} // end namespace detail
END_ROCPRIM_NAMESPACE
/// @}
// end of group primitivesmodule_deviceconfigs
#endif // ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_CONFIG_HPP_
3rdparty/cub/rocprim/device/device_binary_search.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2019 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DEVICE_BINARY_SEARCH_HPP_
#define ROCPRIM_DEVICE_DEVICE_BINARY_SEARCH_HPP_
#include <type_traits>
#include <iterator>
#include "../config.hpp"
#include "../detail/various.hpp"
#include "detail/device_binary_search.hpp"
#include "device_transform.hpp"
BEGIN_ROCPRIM_NAMESPACE
/// \addtogroup devicemodule
/// @{
namespace detail
{

/// Shared implementation behind lower_bound / upper_bound / binary_search:
/// runs \p search_op over every needle via a device-side transform.
template<class Config,
         class HaystackIterator,
         class NeedlesIterator,
         class OutputIterator,
         class SearchFunction,
         class CompareFunction>
inline cudaError_t binary_search(void*            temporary_storage,
                                 size_t&          storage_size,
                                 HaystackIterator haystack,
                                 NeedlesIterator  needles,
                                 OutputIterator   output,
                                 size_t           haystack_size,
                                 size_t           needles_size,
                                 SearchFunction   search_op,
                                 CompareFunction  compare_op,
                                 cudaStream_t     stream,
                                 bool             debug_synchronous)
{
    using value_type = typename std::iterator_traits<NeedlesIterator>::value_type;

    if(temporary_storage == nullptr)
    {
        // Make sure user won't try to allocate 0 bytes memory, otherwise
        // user may again pass nullptr as temporary_storage
        storage_size = 4;
        return cudaSuccess;
    }

    // Each needle is searched independently: map every needle to the result
    // of search_op over the whole haystack.
    return transform<Config>(
        needles,
        output,
        needles_size,
        [haystack, haystack_size, search_op, compare_op] ROCPRIM_DEVICE(const value_type& value)
        { return search_op(haystack, haystack_size, value, compare_op); },
        stream,
        debug_synchronous);
}

} // end of detail namespace
/// \brief For every needle, finds the lower-bound position within a haystack.
///
/// Thin wrapper over detail::binary_search using detail::lower_bound_search_op.
///
/// \param [in] temporary_storage - device temporary storage; pass nullptr to
/// query the required \p storage_size.
/// \param [in,out] storage_size - size (in bytes) of \p temporary_storage.
/// \param [in] haystack - sorted range searched in.
/// \param [in] needles - values to look up.
/// \param [out] output - one result per needle.
/// \param [in] haystack_size / needles_size - element counts of the ranges.
/// \param [in] compare_op - ordering used by the search; defaults to less<>.
/// \param [in] stream - stream to launch on; default is the null stream.
/// \param [in] debug_synchronous - if true, synchronize after kernel launches.
template<class Config          = default_config,
         class HaystackIterator,
         class NeedlesIterator,
         class OutputIterator,
         class CompareFunction = ::rocprim::less<>>
inline cudaError_t lower_bound(void*            temporary_storage,
                               size_t&          storage_size,
                               HaystackIterator haystack,
                               NeedlesIterator  needles,
                               OutputIterator   output,
                               size_t           haystack_size,
                               size_t           needles_size,
                               CompareFunction  compare_op        = CompareFunction(),
                               cudaStream_t     stream            = 0,
                               bool             debug_synchronous = false)
{
    return detail::binary_search<Config>(temporary_storage,
                                         storage_size,
                                         haystack,
                                         needles,
                                         output,
                                         haystack_size,
                                         needles_size,
                                         detail::lower_bound_search_op(),
                                         compare_op,
                                         stream,
                                         debug_synchronous);
}
/// \brief For every needle, finds the upper-bound position within a haystack.
///
/// Identical in structure to ::rocprim::lower_bound, but dispatches to
/// detail::upper_bound_search_op.
///
/// \param [in] temporary_storage - device temporary storage; pass nullptr to
/// query the required \p storage_size.
/// \param [in,out] storage_size - size (in bytes) of \p temporary_storage.
/// \param [in] haystack - sorted range searched in.
/// \param [in] needles - values to look up.
/// \param [out] output - one result per needle.
/// \param [in] haystack_size / needles_size - element counts of the ranges.
/// \param [in] compare_op - ordering used by the search; defaults to less<>.
/// \param [in] stream - stream to launch on; default is the null stream.
/// \param [in] debug_synchronous - if true, synchronize after kernel launches.
template<class Config          = default_config,
         class HaystackIterator,
         class NeedlesIterator,
         class OutputIterator,
         class CompareFunction = ::rocprim::less<>>
inline cudaError_t upper_bound(void*            temporary_storage,
                               size_t&          storage_size,
                               HaystackIterator haystack,
                               NeedlesIterator  needles,
                               OutputIterator   output,
                               size_t           haystack_size,
                               size_t           needles_size,
                               CompareFunction  compare_op        = CompareFunction(),
                               cudaStream_t     stream            = 0,
                               bool             debug_synchronous = false)
{
    return detail::binary_search<Config>(temporary_storage,
                                         storage_size,
                                         haystack,
                                         needles,
                                         output,
                                         haystack_size,
                                         needles_size,
                                         detail::upper_bound_search_op(),
                                         compare_op,
                                         stream,
                                         debug_synchronous);
}
/// \brief Runs a per-needle binary search over a haystack.
///
/// Same wrapper shape as lower_bound/upper_bound, but dispatches to
/// detail::binary_search_op (presumably an existence/membership search —
/// the op's exact output semantics live in detail/device_binary_search.hpp).
///
/// \param [in] temporary_storage - device temporary storage; pass nullptr to
/// query the required \p storage_size.
/// \param [in,out] storage_size - size (in bytes) of \p temporary_storage.
/// \param [in] haystack - sorted range searched in.
/// \param [in] needles - values to look up.
/// \param [out] output - one result per needle.
/// \param [in] haystack_size / needles_size - element counts of the ranges.
/// \param [in] compare_op - ordering used by the search; defaults to less<>.
/// \param [in] stream - stream to launch on; default is the null stream.
/// \param [in] debug_synchronous - if true, synchronize after kernel launches.
template<class Config          = default_config,
         class HaystackIterator,
         class NeedlesIterator,
         class OutputIterator,
         class CompareFunction = ::rocprim::less<>>
inline cudaError_t binary_search(void*            temporary_storage,
                                 size_t&          storage_size,
                                 HaystackIterator haystack,
                                 NeedlesIterator  needles,
                                 OutputIterator   output,
                                 size_t           haystack_size,
                                 size_t           needles_size,
                                 CompareFunction  compare_op        = CompareFunction(),
                                 cudaStream_t     stream            = 0,
                                 bool             debug_synchronous = false)
{
    return detail::binary_search<Config>(temporary_storage,
                                         storage_size,
                                         haystack,
                                         needles,
                                         output,
                                         haystack_size,
                                         needles_size,
                                         detail::binary_search_op(),
                                         compare_op,
                                         stream,
                                         debug_synchronous);
}
/// @}
// end of group devicemodule
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DEVICE_BINARY_SEARCH_HPP_
3rdparty/cub/rocprim/device/device_histogram.hpp
0 → 100644
View file @
f8a481f8
// Copyright (c) 2017-2019 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DEVICE_HISTOGRAM_HPP_
#define ROCPRIM_DEVICE_DEVICE_HISTOGRAM_HPP_
#include <cmath>
#include <type_traits>
#include <iterator>
#include <iostream>
#include <chrono>
#include "../config.hpp"
#include "../functional.hpp"
#include "../detail/various.hpp"
#include "device_histogram_config.hpp"
#include "detail/device_histogram.hpp"
BEGIN_ROCPRIM_NAMESPACE
/// \addtogroup devicemodule
/// @{
namespace
detail
{
/// Kernel entry point that zero-initializes the histogram bins; the actual
/// work is delegated to the device function init_histogram.
template<unsigned int BlockSize, unsigned int ActiveChannels, class Counter>
ROCPRIM_KERNEL __launch_bounds__(BlockSize)
void init_histogram_kernel(fixed_array<Counter*, ActiveChannels>     histogram,
                           fixed_array<unsigned int, ActiveChannels> bins)
{
    init_histogram<BlockSize, ActiveChannels>(histogram, bins);
}
/// Kernel entry point for the shared-memory histogram path. Allocates the
/// dynamic shared-memory block histogram and forwards to histogram_shared.
template<unsigned int BlockSize,
         unsigned int ItemsPerThread,
         unsigned int Channels,
         unsigned int ActiveChannels,
         class SampleIterator,
         class Counter,
         class SampleToBinOp>
ROCPRIM_KERNEL __launch_bounds__(BlockSize)
void histogram_shared_kernel(SampleIterator                             samples,
                             unsigned int                               columns,
                             unsigned int                               rows,
                             unsigned int                               row_stride,
                             unsigned int                               rows_per_block,
                             fixed_array<Counter*, ActiveChannels>      histogram,
                             fixed_array<SampleToBinOp, ActiveChannels> sample_to_bin_op,
                             fixed_array<unsigned int, ActiveChannels>  bins)
{
    // Dynamic shared memory sized by the host (total_bins * sizeof(unsigned int)).
    HIP_DYNAMIC_SHARED(unsigned int, block_histogram);
    histogram_shared<BlockSize, ItemsPerThread, Channels, ActiveChannels>(samples,
                                                                          columns,
                                                                          rows,
                                                                          row_stride,
                                                                          rows_per_block,
                                                                          histogram,
                                                                          sample_to_bin_op,
                                                                          bins,
                                                                          block_histogram);
}
/// Kernel entry point for the global-memory histogram path (used when bins do
/// not fit the shared-memory budget). Forwards to histogram_global.
template<unsigned int BlockSize,
         unsigned int ItemsPerThread,
         unsigned int Channels,
         unsigned int ActiveChannels,
         class SampleIterator,
         class Counter,
         class SampleToBinOp>
ROCPRIM_KERNEL __launch_bounds__(BlockSize)
void histogram_global_kernel(SampleIterator                             samples,
                             unsigned int                               columns,
                             unsigned int                               row_stride,
                             fixed_array<Counter*, ActiveChannels>      histogram,
                             fixed_array<SampleToBinOp, ActiveChannels> sample_to_bin_op,
                             fixed_array<unsigned int, ActiveChannels>  bins_bits)
{
    histogram_global<BlockSize, ItemsPerThread, Channels, ActiveChannels>(samples,
                                                                          columns,
                                                                          row_stride,
                                                                          histogram,
                                                                          sample_to_bin_op,
                                                                          bins_bits);
}
// Helper macro used after each kernel launch below: returns early from the
// enclosing function on any pending runtime error, and — when
// debug_synchronous is true in the enclosing scope — synchronizes `stream`
// and prints the kernel name, item count, and elapsed time since `start`.
#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \
{ \
auto _error = cudaGetLastError(); \
if(_error != cudaSuccess) return _error; \
if(debug_synchronous) \
{ \
std::cout << name << "(" << size << ")"; \
auto __error = cudaStreamSynchronize(stream); \
if(__error != cudaSuccess) return __error; \
auto _end = std::chrono::high_resolution_clock::now(); \
auto _d = std::chrono::duration_cast<std::chrono::duration<double>>(_end - start); \
std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \
} \
}
/// Core multi-channel histogram driver.
///
/// Validates the row stride, initializes the per-channel bin counters on
/// device, then dispatches either the shared-memory kernel (when the total
/// bin count fits config::shared_impl_max_bins) or the global-memory kernel.
///
/// \param [in] temporary_storage - pass nullptr to query \p storage_size.
/// \param [in,out] storage_size - size (in bytes) of \p temporary_storage.
/// \param [in] samples - input sample range (Channels-interleaved).
/// \param [in] columns / rows / row_stride_bytes - 2D region description.
/// \param [out] histogram - one output counter array per active channel.
/// \param [in] levels - per-channel number of bin boundaries (bins = levels-1).
/// \param [in] sample_to_bin_op - per-channel sample-to-bin mapping functors.
/// \param [in] stream - stream to launch on.
/// \param [in] debug_synchronous - if true, synchronize and print after launches.
template<unsigned int Channels,
         unsigned int ActiveChannels,
         class Config,
         class SampleIterator,
         class Counter,
         class SampleToBinOp>
inline cudaError_t histogram_impl(void*          temporary_storage,
                                  size_t&        storage_size,
                                  SampleIterator samples,
                                  unsigned int   columns,
                                  unsigned int   rows,
                                  size_t         row_stride_bytes,
                                  Counter*       histogram[ActiveChannels],
                                  unsigned int   levels[ActiveChannels],
                                  SampleToBinOp  sample_to_bin_op[ActiveChannels],
                                  cudaStream_t   stream,
                                  bool           debug_synchronous)
{
    using sample_type = typename std::iterator_traits<SampleIterator>::value_type;
    using config      = default_or_custom_config<
        Config,
        default_histogram_config<ROCPRIM_TARGET_ARCH, sample_type, Channels, ActiveChannels>>;

    static constexpr unsigned int block_size       = config::histogram::block_size;
    static constexpr unsigned int items_per_thread = config::histogram::items_per_thread;
    static constexpr unsigned int items_per_block  = block_size * items_per_thread;

    if(row_stride_bytes % sizeof(sample_type) != 0)
    {
        // Row stride must be a whole multiple of the sample data type size
        return cudaErrorInvalidValue;
    }

    const unsigned int blocks_x   = ::rocprim::detail::ceiling_div(columns, items_per_block);
    const unsigned int row_stride = row_stride_bytes / sizeof(sample_type);

    if(temporary_storage == nullptr)
    {
        // Make sure user won't try to allocate 0 bytes memory, because
        // cudaMalloc will return nullptr.
        storage_size = 4;
        return cudaSuccess;
    }

    if(debug_synchronous)
    {
        std::cout << "columns " << columns << '\n';
        std::cout << "rows " << rows << '\n';
        std::cout << "blocks_x " << blocks_x << '\n';
        cudaError_t error = cudaStreamSynchronize(stream);
        if(error != cudaSuccess)
            return error;
    }

    // Per-channel bin counts, their bit widths (rounded up to a power of two),
    // and the totals used to choose between the shared and global kernels.
    unsigned int bins[ActiveChannels];
    unsigned int bins_bits[ActiveChannels];
    unsigned int total_bins = 0;
    unsigned int max_bins   = 0;
    for(unsigned int channel = 0; channel < ActiveChannels; channel++)
    {
        bins[channel]      = levels[channel] - 1;
        bins_bits[channel] = static_cast<unsigned int>(
            std::log2(detail::next_power_of_two(bins[channel])));
        total_bins += bins[channel];
        max_bins = std::max(max_bins, bins[channel]);
    }

    std::chrono::high_resolution_clock::time_point start;

    if(debug_synchronous)
        start = std::chrono::high_resolution_clock::now();
    init_histogram_kernel<block_size, ActiveChannels>
        <<<dim3(::rocprim::detail::ceiling_div(max_bins, block_size)),
           dim3(block_size),
           0,
           stream>>>(fixed_array<Counter*, ActiveChannels>(histogram),
                     fixed_array<unsigned int, ActiveChannels>(bins));
    ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("init_histogram", max_bins, start);

    if(columns == 0 || rows == 0)
    {
        return cudaSuccess;
    }

    if(total_bins <= config::shared_impl_max_bins)
    {
        // Shared-memory path: privatized per-block histograms in LDS.
        dim3 grid_size;
        grid_size.x = std::min(config::max_grid_size, blocks_x);
        grid_size.y = std::min(rows, config::max_grid_size / grid_size.x);
        const size_t       block_histogram_bytes = total_bins * sizeof(unsigned int);
        const unsigned int rows_per_block = ::rocprim::detail::ceiling_div(rows, grid_size.y);
        if(debug_synchronous)
            start = std::chrono::high_resolution_clock::now();
        histogram_shared_kernel<block_size, items_per_thread, Channels, ActiveChannels>
            <<<grid_size, dim3(block_size, 1), block_histogram_bytes, stream>>>(
                samples,
                columns,
                rows,
                row_stride,
                rows_per_block,
                fixed_array<Counter*, ActiveChannels>(histogram),
                fixed_array<SampleToBinOp, ActiveChannels>(sample_to_bin_op),
                fixed_array<unsigned int, ActiveChannels>(bins));
        ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("histogram_shared",
                                                    grid_size.x * grid_size.y * block_size,
                                                    start);
    }
    else
    {
        // Global-memory path: too many bins for shared memory.
        if(debug_synchronous)
            start = std::chrono::high_resolution_clock::now();
        histogram_global_kernel<block_size, items_per_thread, Channels, ActiveChannels>
            <<<dim3(blocks_x, rows), dim3(block_size, 1), 0, stream>>>(
                samples,
                columns,
                row_stride,
                fixed_array<Counter*, ActiveChannels>(histogram),
                fixed_array<SampleToBinOp, ActiveChannels>(sample_to_bin_op),
                fixed_array<unsigned int, ActiveChannels>(bins_bits));
        ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("histogram_global",
                                                    blocks_x * block_size * rows,
                                                    start);
    }

    return cudaSuccess;
}
/// Builds per-channel sample_to_bin_even functors from (lower, upper, levels)
/// and delegates to histogram_impl. Rejects channels with fewer than 2 levels.
template<unsigned int Channels,
         unsigned int ActiveChannels,
         class Config,
         class SampleIterator,
         class Counter,
         class Level>
inline cudaError_t histogram_even_impl(void*          temporary_storage,
                                       size_t&        storage_size,
                                       SampleIterator samples,
                                       unsigned int   columns,
                                       unsigned int   rows,
                                       size_t         row_stride_bytes,
                                       Counter*       histogram[ActiveChannels],
                                       unsigned int   levels[ActiveChannels],
                                       Level          lower_level[ActiveChannels],
                                       Level          upper_level[ActiveChannels],
                                       cudaStream_t   stream,
                                       bool           debug_synchronous)
{
    for(unsigned int channel = 0; channel < ActiveChannels; channel++)
    {
        if(levels[channel] < 2)
        {
            // Histogram must have at least 1 bin
            return cudaErrorInvalidValue;
        }
    }

    // One equal-width sample-to-bin mapping per active channel.
    sample_to_bin_even<Level> sample_to_bin_op[ActiveChannels];
    for(unsigned int channel = 0; channel < ActiveChannels; channel++)
    {
        sample_to_bin_op[channel] = sample_to_bin_even<Level>(levels[channel] - 1,
                                                              lower_level[channel],
                                                              upper_level[channel]);
    }

    return histogram_impl<Channels, ActiveChannels, Config>(temporary_storage,
                                                            storage_size,
                                                            samples,
                                                            columns,
                                                            rows,
                                                            row_stride_bytes,
                                                            histogram,
                                                            levels,
                                                            sample_to_bin_op,
                                                            stream,
                                                            debug_synchronous);
}
/// Builds per-channel sample_to_bin_range functors from explicit level-value
/// arrays and delegates to histogram_impl. Rejects channels with fewer than
/// 2 levels.
template<unsigned int Channels,
         unsigned int ActiveChannels,
         class Config,
         class SampleIterator,
         class Counter,
         class Level>
inline cudaError_t histogram_range_impl(void*          temporary_storage,
                                        size_t&        storage_size,
                                        SampleIterator samples,
                                        unsigned int   columns,
                                        unsigned int   rows,
                                        size_t         row_stride_bytes,
                                        Counter*       histogram[ActiveChannels],
                                        unsigned int   levels[ActiveChannels],
                                        Level*         level_values[ActiveChannels],
                                        cudaStream_t   stream,
                                        bool           debug_synchronous)
{
    for(unsigned int channel = 0; channel < ActiveChannels; channel++)
    {
        if(levels[channel] < 2)
        {
            // Histogram must have at least 1 bin
            return cudaErrorInvalidValue;
        }
    }

    // One range-based sample-to-bin mapping per active channel; bin edges are
    // read from the caller-supplied level_values arrays.
    sample_to_bin_range<Level> sample_to_bin_op[ActiveChannels];
    for(unsigned int channel = 0; channel < ActiveChannels; channel++)
    {
        sample_to_bin_op[channel]
            = sample_to_bin_range<Level>(levels[channel] - 1, level_values[channel]);
    }

    return histogram_impl<Channels, ActiveChannels, Config>(temporary_storage,
                                                            storage_size,
                                                            samples,
                                                            columns,
                                                            rows,
                                                            row_stride_bytes,
                                                            histogram,
                                                            levels,
                                                            sample_to_bin_op,
                                                            stream,
                                                            debug_synchronous);
}
#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR
}
// end of detail namespace
/// \brief Computes a histogram from a sequence of samples using equal-width bins.
///
/// \par
/// * The number of histogram bins is (\p levels - 1).
/// * Bins are evenly-segmented and include the same width of sample values:
/// (\p upper_level - \p lower_level) / (\p levels - 1).
/// * Returns the required size of \p temporary_storage in \p storage_size
/// if \p temporary_storage is a null pointer.
///
/// \tparam Config - [optional] configuration of the primitive. It can be \p histogram_config or
/// a custom class with the same members.
/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the
/// requirements of a C++ InputIterator concept. It can be a simple pointer type.
/// \tparam Counter - integer type for histogram bin counters.
/// \tparam Level - type of histogram boundaries (levels)
///
/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When
/// a null pointer is passed, the required allocation size (in bytes) is written to
/// \p storage_size and function returns without performing the reduction operation.
/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage.
/// \param [in] samples - iterator to the first element in the range of input samples.
/// \param [in] size - number of elements in the samples range.
/// \param [out] histogram - pointer to the first element in the histogram range.
/// \param [in] levels - number of boundaries (levels) for histogram bins.
/// \param [in] lower_level - lower sample value bound (inclusive) for the first histogram bin.
/// \param [in] upper_level - upper sample value bound (exclusive) for the last histogram bin.
/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream).
/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel
/// launch is forced in order to check for errors. Default value is \p false.
///
/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of
/// type \p cudaError_t.
///
/// \par Example
/// \parblock
/// In this example a device-level histogram of 5 bins is computed on an array of float samples.
///
/// \code{.cpp}
/// #include <rocprim/rocprim.hpp>
///
/// // Prepare input and output (declare pointers, allocate device memory etc.)
/// unsigned int size; // e.g., 8
/// float * samples; // e.g., [-10.0, 0.3, 9.5, 8.1, 1.5, 1.9, 100.0, 5.1]
/// int * histogram; // empty array of at least 5 elements
/// unsigned int levels; // e.g., 6 (for 5 bins)
/// float lower_level; // e.g., 0.0
/// float upper_level; // e.g., 10.0
///
/// size_t temporary_storage_size_bytes;
/// void * temporary_storage_ptr = nullptr;
/// // Get required size of the temporary storage
/// rocprim::histogram_even(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, size,
/// histogram, levels, lower_level, upper_level
/// );
///
/// // allocate temporary storage
/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes);
///
/// // compute histogram
/// rocprim::histogram_even(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, size,
/// histogram, levels, lower_level, upper_level
/// );
/// // histogram: [3, 0, 1, 0, 2]
/// \endcode
/// \endparblock
template<class Config = default_config, class SampleIterator, class Counter, class Level>
inline cudaError_t histogram_even(void*          temporary_storage,
                                  size_t&        storage_size,
                                  SampleIterator samples,
                                  unsigned int   size,
                                  Counter*       histogram,
                                  unsigned int   levels,
                                  Level          lower_level,
                                  Level          upper_level,
                                  cudaStream_t   stream            = 0,
                                  bool           debug_synchronous = false)
{
    // Wrap the scalar arguments into single-channel arrays so the 1D case can
    // reuse the multi-channel implementation (1 channel, 1 active channel,
    // a single "row" of `size` samples with zero stride).
    Counter*     histogram_single[1]   = {histogram};
    unsigned int levels_single[1]      = {levels};
    Level        lower_level_single[1] = {lower_level};
    Level        upper_level_single[1] = {upper_level};

    return detail::histogram_even_impl<1, 1, Config>(temporary_storage,
                                                     storage_size,
                                                     samples,
                                                     size,
                                                     1,
                                                     0,
                                                     histogram_single,
                                                     levels_single,
                                                     lower_level_single,
                                                     upper_level_single,
                                                     stream,
                                                     debug_synchronous);
}
/// \brief Computes a histogram from a two-dimensional region of samples using equal-width bins.
///
/// \par
/// * The two-dimensional region of interest within \p samples can be specified using the \p columns,
/// \p rows and \p row_stride_bytes parameters.
/// * The row stride must be a whole multiple of the sample data type size,
/// i.e., <tt>(row_stride_bytes % sizeof(std::iterator_traits<SampleIterator>::value_type)) == 0</tt>.
/// * The number of histogram bins is (\p levels - 1).
/// * Bins are evenly-segmented and include the same width of sample values:
/// (\p upper_level - \p lower_level) / (\p levels - 1).
/// * Returns the required size of \p temporary_storage in \p storage_size
/// if \p temporary_storage is a null pointer.
///
/// \tparam Config - [optional] configuration of the primitive. It can be \p histogram_config or
/// a custom class with the same members.
/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the
/// requirements of a C++ InputIterator concept. It can be a simple pointer type.
/// \tparam Counter - integer type for histogram bin counters.
/// \tparam Level - type of histogram boundaries (levels)
///
/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When
/// a null pointer is passed, the required allocation size (in bytes) is written to
/// \p storage_size and function returns without performing the reduction operation.
/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage.
/// \param [in] samples - iterator to the first element in the range of input samples.
/// \param [in] columns - number of elements in each row of the region.
/// \param [in] rows - number of rows of the region.
/// \param [in] row_stride_bytes - number of bytes between starts of consecutive rows of the region.
/// \param [out] histogram - pointer to the first element in the histogram range.
/// \param [in] levels - number of boundaries (levels) for histogram bins.
/// \param [in] lower_level - lower sample value bound (inclusive) for the first histogram bin.
/// \param [in] upper_level - upper sample value bound (exclusive) for the last histogram bin.
/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream).
/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel
/// launch is forced in order to check for errors. Default value is \p false.
///
/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of
/// type \p cudaError_t.
///
/// \par Example
/// \parblock
/// In this example a device-level histogram of 5 bins is computed on an array of float samples.
///
/// \code{.cpp}
/// #include <rocprim/rocprim.hpp>
///
/// // Prepare input and output (declare pointers, allocate device memory etc.)
/// unsigned int columns; // e.g., 4
/// unsigned int rows; // e.g., 2
/// size_t row_stride_bytes; // e.g., 6 * sizeof(float)
/// float * samples; // e.g., [-10.0, 0.3, 9.5, 8.1, -, -, 1.5, 1.9, 100.0, 5.1, -, -]
/// int * histogram; // empty array of at least 5 elements
/// unsigned int levels; // e.g., 6 (for 5 bins)
/// float lower_level; // e.g., 0.0
/// float upper_level; // e.g., 10.0
///
/// size_t temporary_storage_size_bytes;
/// void * temporary_storage_ptr = nullptr;
/// // Get required size of the temporary storage
/// rocprim::histogram_even(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, columns, rows, row_stride_bytes,
/// histogram, levels, lower_level, upper_level
/// );
///
/// // allocate temporary storage
/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes);
///
/// // compute histogram
/// rocprim::histogram_even(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, columns, rows, row_stride_bytes,
/// histogram, levels, lower_level, upper_level
/// );
/// // histogram: [3, 0, 1, 0, 2]
/// \endcode
/// \endparblock
template<class Config = default_config, class SampleIterator, class Counter, class Level>
inline cudaError_t histogram_even(void*          temporary_storage,
                                  size_t&        storage_size,
                                  SampleIterator samples,
                                  unsigned int   columns,
                                  unsigned int   rows,
                                  size_t         row_stride_bytes,
                                  Counter*       histogram,
                                  unsigned int   levels,
                                  Level          lower_level,
                                  Level          upper_level,
                                  cudaStream_t   stream            = 0,
                                  bool           debug_synchronous = false)
{
    // Promote the scalar single-channel arguments to one-element arrays and
    // forward to the multi-channel implementation (1 channel, 1 active).
    Counter*     histogram_single[1]   = {histogram};
    unsigned int levels_single[1]      = {levels};
    Level        lower_level_single[1] = {lower_level};
    Level        upper_level_single[1] = {upper_level};

    return detail::histogram_even_impl<1, 1, Config>(temporary_storage,
                                                     storage_size,
                                                     samples,
                                                     columns,
                                                     rows,
                                                     row_stride_bytes,
                                                     histogram_single,
                                                     levels_single,
                                                     lower_level_single,
                                                     upper_level_single,
                                                     stream,
                                                     debug_synchronous);
}
/// \brief Computes histograms from a sequence of multi-channel samples using equal-width bins.
///
/// \par
/// * The input is a sequence of <em>pixel</em> structures, where each pixel comprises
/// a record of \p Channels consecutive data samples (e.g., \p Channels = 4 for <em>RGBA</em> samples).
/// * The first \p ActiveChannels channels of total \p Channels channels will be used for computing histograms
/// (e.g., \p ActiveChannels = 3 for computing histograms of only <em>RGB</em> from <em>RGBA</em> samples).
/// * For channel<sub><em>i</em></sub> the number of histogram bins is (\p levels[i] - 1).
/// * For channel<sub><em>i</em></sub> bins are evenly-segmented and include the same width of sample values:
/// (\p upper_level[i] - \p lower_level[i]) / (\p levels[i] - 1).
/// * Returns the required size of \p temporary_storage in \p storage_size
/// if \p temporary_storage is a null pointer.
///
/// \tparam Channels - number of channels interleaved in the input samples.
/// \tparam ActiveChannels - number of channels being used for computing histograms.
/// \tparam Config - [optional] configuration of the primitive. It can be \p histogram_config or
/// a custom class with the same members.
/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the
/// requirements of a C++ InputIterator concept. It can be a simple pointer type.
/// \tparam Counter - integer type for histogram bin counters.
/// \tparam Level - type of histogram boundaries (levels)
///
/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When
/// a null pointer is passed, the required allocation size (in bytes) is written to
/// \p storage_size and function returns without performing the reduction operation.
/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage.
/// \param [in] samples - iterator to the first element in the range of input samples.
/// \param [in] size - number of pixels in the samples range.
/// \param [out] histogram - pointers to the first element in the histogram range, one for each active channel.
/// \param [in] levels - number of boundaries (levels) for histogram bins in each active channel.
/// \param [in] lower_level - lower sample value bound (inclusive) for the first histogram bin in each active channel.
/// \param [in] upper_level - upper sample value bound (exclusive) for the last histogram bin in each active channel.
/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream).
/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel
/// launch is forced in order to check for errors. Default value is \p false.
///
/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of
/// type \p cudaError_t.
///
/// \par Example
/// \parblock
/// In this example histograms for 3 channels (RGB) are computed on an array of 8-bit RGBA samples.
///
/// \code{.cpp}
/// #include <rocprim/rocprim.hpp>
///
/// // Prepare input and output (declare pointers, allocate device memory etc.)
/// unsigned int size; // e.g., 8
/// unsigned char * samples; // e.g., [(3, 1, 5, 255), (3, 1, 5, 255), (4, 2, 6, 127), (3, 2, 6, 127),
/// // (0, 0, 0, 100), (0, 1, 0, 100), (0, 0, 1, 255), (0, 1, 1, 255)]
/// int * histogram[3]; // 3 empty arrays of at least 256 elements each
/// unsigned int levels[3]; // e.g., [257, 257, 257] (for 256 bins)
/// int lower_level[3]; // e.g., [0, 0, 0]
/// int upper_level[3]; // e.g., [256, 256, 256]
///
/// size_t temporary_storage_size_bytes;
/// void * temporary_storage_ptr = nullptr;
/// // Get required size of the temporary storage
/// rocprim::multi_histogram_even<4, 3>(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, size,
/// histogram, levels, lower_level, upper_level
/// );
///
/// // allocate temporary storage
/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes);
///
/// // compute histograms
/// rocprim::multi_histogram_even<4, 3>(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, size,
/// histogram, levels, lower_level, upper_level
/// );
/// // histogram: [[4, 0, 0, 3, 1, 0, 0, ..., 0],
/// // [2, 4, 2, 0, 0, 0, 0, ..., 0],
/// // [2, 2, 0, 0, 0, 2, 2, ..., 0]]
/// \endcode
/// \endparblock
template<unsigned int Channels,
         unsigned int ActiveChannels,
         class Config = default_config,
         class SampleIterator,
         class Counter,
         class Level>
inline cudaError_t multi_histogram_even(void*          temporary_storage,
                                        size_t&        storage_size,
                                        SampleIterator samples,
                                        unsigned int   size,
                                        Counter*       histogram[ActiveChannels],
                                        unsigned int   levels[ActiveChannels],
                                        Level          lower_level[ActiveChannels],
                                        Level          upper_level[ActiveChannels],
                                        cudaStream_t   stream            = 0,
                                        bool           debug_synchronous = false)
{
    // 1D multi-channel case: a single "row" of `size` pixels, zero row stride.
    return detail::histogram_even_impl<Channels, ActiveChannels, Config>(temporary_storage,
                                                                         storage_size,
                                                                         samples,
                                                                         size,
                                                                         1,
                                                                         0,
                                                                         histogram,
                                                                         levels,
                                                                         lower_level,
                                                                         upper_level,
                                                                         stream,
                                                                         debug_synchronous);
}
/// \brief Computes histograms from a two-dimensional region of multi-channel samples using equal-width bins.
///
/// \par
/// * The two-dimensional region of interest within \p samples can be specified using the \p columns,
/// \p rows and \p row_stride_bytes parameters.
/// * The row stride must be a whole multiple of the sample data type size,
/// i.e., <tt>(row_stride_bytes % sizeof(std::iterator_traits<SampleIterator>::value_type)) == 0</tt>.
/// * The input is a sequence of <em>pixel</em> structures, where each pixel comprises
/// a record of \p Channels consecutive data samples (e.g., \p Channels = 4 for <em>RGBA</em> samples).
/// * The first \p ActiveChannels channels of total \p Channels channels will be used for computing histograms
/// (e.g., \p ActiveChannels = 3 for computing histograms of only <em>RGB</em> from <em>RGBA</em> samples).
/// * For channel<sub><em>i</em></sub> the number of histogram bins is (\p levels[i] - 1).
/// * For channel<sub><em>i</em></sub> bins are evenly-segmented and include the same width of sample values:
/// (\p upper_level[i] - \p lower_level[i]) / (\p levels[i] - 1).
/// * Returns the required size of \p temporary_storage in \p storage_size
/// if \p temporary_storage is a null pointer.
///
/// \tparam Channels - number of channels interleaved in the input samples.
/// \tparam ActiveChannels - number of channels being used for computing histograms.
/// \tparam Config - [optional] configuration of the primitive. It can be \p histogram_config or
/// a custom class with the same members.
/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the
/// requirements of a C++ InputIterator concept. It can be a simple pointer type.
/// \tparam Counter - integer type for histogram bin counters.
/// \tparam Level - type of histogram boundaries (levels)
///
/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When
/// a null pointer is passed, the required allocation size (in bytes) is written to
/// \p storage_size and function returns without performing the reduction operation.
/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage.
/// \param [in] samples - iterator to the first element in the range of input samples.
/// \param [in] columns - number of elements in each row of the region.
/// \param [in] rows - number of rows of the region.
/// \param [in] row_stride_bytes - number of bytes between starts of consecutive rows of the region.
/// \param [out] histogram - pointers to the first element in the histogram range, one for each active channel.
/// \param [in] levels - number of boundaries (levels) for histogram bins in each active channel.
/// \param [in] lower_level - lower sample value bound (inclusive) for the first histogram bin in each active channel.
/// \param [in] upper_level - upper sample value bound (exclusive) for the last histogram bin in each active channel.
/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream).
/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel
/// launch is forced in order to check for errors. Default value is \p false.
///
/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of
/// type \p cudaError_t.
///
/// \par Example
/// \parblock
/// In this example histograms for 3 channels (RGB) are computed on an array of 8-bit RGBA samples.
///
/// \code{.cpp}
/// #include <rocprim/rocprim.hpp>
///
/// // Prepare input and output (declare pointers, allocate device memory etc.)
/// unsigned int columns; // e.g., 4
/// unsigned int rows; // e.g., 2
/// size_t row_stride_bytes; // e.g., 5 * sizeof(unsigned char)
/// unsigned char * samples; // e.g., [(3, 1, 5, 255), (3, 1, 5, 255), (4, 2, 6, 127), (3, 2, 6, 127), (-, -, -, -),
/// // (0, 0, 0, 100), (0, 1, 0, 100), (0, 0, 1, 255), (0, 1, 1, 255), (-, -, -, -)]
/// int * histogram[3]; // 3 empty arrays of at least 256 elements each
/// unsigned int levels[3]; // e.g., [257, 257, 257] (for 256 bins)
/// int lower_level[3]; // e.g., [0, 0, 0]
/// int upper_level[3]; // e.g., [256, 256, 256]
///
/// size_t temporary_storage_size_bytes;
/// void * temporary_storage_ptr = nullptr;
/// // Get required size of the temporary storage
/// rocprim::multi_histogram_even<4, 3>(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, columns, rows, row_stride_bytes,
/// histogram, levels, lower_level, upper_level
/// );
///
/// // allocate temporary storage
/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes);
///
/// // compute histograms
/// rocprim::multi_histogram_even<4, 3>(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, columns, rows, row_stride_bytes,
/// histogram, levels, lower_level, upper_level
/// );
/// // histogram: [[4, 0, 0, 3, 1, 0, 0, ..., 0],
/// // [2, 4, 2, 0, 0, 0, 0, ..., 0],
/// // [2, 2, 0, 0, 0, 2, 2, ..., 0]]
/// \endcode
/// \endparblock
template
<
unsigned
int
Channels
,
unsigned
int
ActiveChannels
,
class
Config
=
default_config
,
class
SampleIterator
,
class
Counter
,
class
Level
>
inline
cudaError_t
multi_histogram_even
(
void
*
temporary_storage
,
size_t
&
storage_size
,
SampleIterator
samples
,
unsigned
int
columns
,
unsigned
int
rows
,
size_t
row_stride_bytes
,
Counter
*
histogram
[
ActiveChannels
],
unsigned
int
levels
[
ActiveChannels
],
Level
lower_level
[
ActiveChannels
],
Level
upper_level
[
ActiveChannels
],
cudaStream_t
stream
=
0
,
bool
debug_synchronous
=
false
)
{
return
detail
::
histogram_even_impl
<
Channels
,
ActiveChannels
,
Config
>
(
temporary_storage
,
storage_size
,
samples
,
columns
,
rows
,
row_stride_bytes
,
histogram
,
levels
,
lower_level
,
upper_level
,
stream
,
debug_synchronous
);
}
/// \brief Computes a histogram from a sequence of samples using the specified bin boundary levels.
///
/// \par
/// * The number of histogram bins is (\p levels - 1).
/// * The range for bin<sub><em>j</em></sub> is [<tt>level_values[j]</tt>, <tt>level_values[j+1]</tt>).
/// * Returns the required size of \p temporary_storage in \p storage_size
/// if \p temporary_storage is a null pointer.
///
/// \tparam Config - [optional] configuration of the primitive. It can be \p histogram_config or
/// a custom class with the same members.
/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the
/// requirements of a C++ InputIterator concept. It can be a simple pointer type.
/// \tparam Counter - integer type for histogram bin counters.
/// \tparam Level - type of histogram boundaries (levels)
///
/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When
/// a null pointer is passed, the required allocation size (in bytes) is written to
/// \p storage_size and function returns without performing the reduction operation.
/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage.
/// \param [in] samples - iterator to the first element in the range of input samples.
/// \param [in] size - number of elements in the samples range.
/// \param [out] histogram - pointer to the first element in the histogram range.
/// \param [in] levels - number of boundaries (levels) for histogram bins.
/// \param [in] level_values - pointer to the array of bin boundaries.
/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream).
/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel
/// launch is forced in order to check for errors. Default value is \p false.
///
/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of
/// type \p cudaError_t.
///
/// \par Example
/// \parblock
/// In this example a device-level histogram of 5 bins is computed on an array of float samples.
///
/// \code{.cpp}
/// #include <rocprim/rocprim.hpp>
///
/// // Prepare input and output (declare pointers, allocate device memory etc.)
/// unsigned int size; // e.g., 8
/// float * samples; // e.g., [-10.0, 0.3, 9.5, 8.1, 1.5, 1.9, 100.0, 5.1]
/// int * histogram; // empty array of at least 5 elements
/// unsigned int levels; // e.g., 6 (for 5 bins)
/// float * level_values; // e.g., [0.0, 1.0, 5.0, 10.0, 20.0, 50.0]
///
/// size_t temporary_storage_size_bytes;
/// void * temporary_storage_ptr = nullptr;
/// // Get required size of the temporary storage
/// rocprim::histogram_range(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, size,
/// histogram, levels, level_values
/// );
///
/// // allocate temporary storage
/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes);
///
/// // compute histogram
/// rocprim::histogram_range(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, size,
/// histogram, levels, level_values
/// );
/// // histogram: [1, 2, 3, 0, 0]
/// \endcode
/// \endparblock
template
<
class
Config
=
default_config
,
class
SampleIterator
,
class
Counter
,
class
Level
>
inline
cudaError_t
histogram_range
(
void
*
temporary_storage
,
size_t
&
storage_size
,
SampleIterator
samples
,
unsigned
int
size
,
Counter
*
histogram
,
unsigned
int
levels
,
Level
*
level_values
,
cudaStream_t
stream
=
0
,
bool
debug_synchronous
=
false
)
{
Counter
*
histogram_single
[
1
]
=
{
histogram
};
unsigned
int
levels_single
[
1
]
=
{
levels
};
Level
*
level_values_single
[
1
]
=
{
level_values
};
return
detail
::
histogram_range_impl
<
1
,
1
,
Config
>
(
temporary_storage
,
storage_size
,
samples
,
size
,
1
,
0
,
histogram_single
,
levels_single
,
level_values_single
,
stream
,
debug_synchronous
);
}
/// \brief Computes a histogram from a two-dimensional region of samples using the specified bin boundary levels.
///
/// \par
/// * The two-dimensional region of interest within \p samples can be specified using the \p columns,
/// \p rows and \p row_stride_bytes parameters.
/// * The row stride must be a whole multiple of the sample data type size,
/// i.e., <tt>(row_stride_bytes % sizeof(std::iterator_traits<SampleIterator>::value_type)) == 0</tt>.
/// * The number of histogram bins is (\p levels - 1).
/// * The range for bin<sub><em>j</em></sub> is [<tt>level_values[j]</tt>, <tt>level_values[j+1]</tt>).
/// * Returns the required size of \p temporary_storage in \p storage_size
/// if \p temporary_storage is a null pointer.
///
/// \tparam Config - [optional] configuration of the primitive. It can be \p histogram_config or
/// a custom class with the same members.
/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the
/// requirements of a C++ InputIterator concept. It can be a simple pointer type.
/// \tparam Counter - integer type for histogram bin counters.
/// \tparam Level - type of histogram boundaries (levels)
///
/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When
/// a null pointer is passed, the required allocation size (in bytes) is written to
/// \p storage_size and function returns without performing the reduction operation.
/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage.
/// \param [in] samples - iterator to the first element in the range of input samples.
/// \param [in] columns - number of elements in each row of the region.
/// \param [in] rows - number of rows of the region.
/// \param [in] row_stride_bytes - number of bytes between starts of consecutive rows of the region.
/// \param [out] histogram - pointer to the first element in the histogram range.
/// \param [in] levels - number of boundaries (levels) for histogram bins.
/// \param [in] level_values - pointer to the array of bin boundaries.
/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream).
/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel
/// launch is forced in order to check for errors. Default value is \p false.
///
/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of
/// type \p cudaError_t.
///
/// \par Example
/// \parblock
/// In this example a device-level histogram of 5 bins is computed on an array of float samples.
///
/// \code{.cpp}
/// #include <rocprim/rocprim.hpp>
///
/// // Prepare input and output (declare pointers, allocate device memory etc.)
/// unsigned int columns; // e.g., 4
/// unsigned int rows; // e.g., 2
/// size_t row_stride_bytes; // e.g., 6 * sizeof(float)
/// float * samples; // e.g., [-10.0, 0.3, 9.5, 8.1, 1.5, 1.9, 100.0, 5.1]
/// int * histogram; // empty array of at least 5 elements
/// unsigned int levels; // e.g., 6 (for 5 bins)
/// float * level_values; // e.g., [0.0, 1.0, 5.0, 10.0, 20.0, 50.0]
///
/// size_t temporary_storage_size_bytes;
/// void * temporary_storage_ptr = nullptr;
/// // Get required size of the temporary storage
/// rocprim::histogram_range(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, columns, rows, row_stride_bytes,
/// histogram, levels, level_values
/// );
///
/// // allocate temporary storage
/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes);
///
/// // compute histogram
/// rocprim::histogram_range(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, columns, rows, row_stride_bytes,
/// histogram, levels, level_values
/// );
/// // histogram: [1, 2, 3, 0, 0]
/// \endcode
/// \endparblock
template
<
class
Config
=
default_config
,
class
SampleIterator
,
class
Counter
,
class
Level
>
inline
cudaError_t
histogram_range
(
void
*
temporary_storage
,
size_t
&
storage_size
,
SampleIterator
samples
,
unsigned
int
columns
,
unsigned
int
rows
,
size_t
row_stride_bytes
,
Counter
*
histogram
,
unsigned
int
levels
,
Level
*
level_values
,
cudaStream_t
stream
=
0
,
bool
debug_synchronous
=
false
)
{
Counter
*
histogram_single
[
1
]
=
{
histogram
};
unsigned
int
levels_single
[
1
]
=
{
levels
};
Level
*
level_values_single
[
1
]
=
{
level_values
};
return
detail
::
histogram_range_impl
<
1
,
1
,
Config
>
(
temporary_storage
,
storage_size
,
samples
,
columns
,
rows
,
row_stride_bytes
,
histogram_single
,
levels_single
,
level_values_single
,
stream
,
debug_synchronous
);
}
/// \brief Computes histograms from a sequence of multi-channel samples using the specified bin boundary levels.
///
/// \par
/// * The input is a sequence of <em>pixel</em> structures, where each pixel comprises
/// a record of \p Channels consecutive data samples (e.g., \p Channels = 4 for <em>RGBA</em> samples).
/// * The first \p ActiveChannels channels of total \p Channels channels will be used for computing histograms
/// (e.g., \p ActiveChannels = 3 for computing histograms of only <em>RGB</em> from <em>RGBA</em> samples).
/// * For channel<sub><em>i</em></sub> the number of histogram bins is (\p levels[i] - 1).
/// * For channel<sub><em>i</em></sub> the range for bin<sub><em>j</em></sub> is
/// [<tt>level_values[i][j]</tt>, <tt>level_values[i][j+1]</tt>).
/// * Returns the required size of \p temporary_storage in \p storage_size
/// if \p temporary_storage is a null pointer.
///
/// \tparam Channels - number of channels interleaved in the input samples.
/// \tparam ActiveChannels - number of channels being used for computing histograms.
/// \tparam Config - [optional] configuration of the primitive. It can be \p histogram_config or
/// a custom class with the same members.
/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the
/// requirements of a C++ InputIterator concept. It can be a simple pointer type.
/// \tparam Counter - integer type for histogram bin counters.
/// \tparam Level - type of histogram boundaries (levels)
///
/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When
/// a null pointer is passed, the required allocation size (in bytes) is written to
/// \p storage_size and function returns without performing the reduction operation.
/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage.
/// \param [in] samples - iterator to the first element in the range of input samples.
/// \param [in] size - number of pixels in the samples range.
/// \param [out] histogram - pointers to the first element in the histogram range, one for each active channel.
/// \param [in] levels - number of boundaries (levels) for histogram bins in each active channel.
/// \param [in] level_values - pointer to the array of bin boundaries for each active channel.
/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream).
/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel
/// launch is forced in order to check for errors. Default value is \p false.
///
/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of
/// type \p cudaError_t.
///
/// \par Example
/// \parblock
/// In this example histograms for 3 channels (RGB) are computed on an array of 8-bit RGBA samples.
///
/// \code{.cpp}
/// #include <rocprim/rocprim.hpp>
///
/// // Prepare input and output (declare pointers, allocate device memory etc.)
/// unsigned int size; // e.g., 8
/// unsigned char * samples; // e.g., [(0, 0, 80, 255), (120, 0, 80, 255), (123, 0, 82, 127), (10, 1, 83, 127),
/// // (51, 1, 8, 100), (52, 1, 8, 100), (53, 0, 81, 255), (54, 50, 81, 255)]
/// int * histogram[3]; // 3 empty arrays of at least 256 elements each
/// unsigned int levels[3]; // e.g., [4, 4, 3]
/// int * level_values[3]; // e.g., [[0, 50, 100, 200], [0, 20, 40, 60], [0, 10, 100]]
///
/// size_t temporary_storage_size_bytes;
/// void * temporary_storage_ptr = nullptr;
/// // Get required size of the temporary storage
/// rocprim::multi_histogram_range<4, 3>(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, size,
/// histogram, levels, level_values
/// );
///
/// // allocate temporary storage
/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes);
///
/// // compute histograms
/// rocprim::multi_histogram_range<4, 3>(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, size,
/// histogram, levels, level_values
/// );
/// // histogram: [[2, 4, 2], [7, 0, 1], [2, 6]]
/// \endcode
/// \endparblock
template
<
unsigned
int
Channels
,
unsigned
int
ActiveChannels
,
class
Config
=
default_config
,
class
SampleIterator
,
class
Counter
,
class
Level
>
inline
cudaError_t
multi_histogram_range
(
void
*
temporary_storage
,
size_t
&
storage_size
,
SampleIterator
samples
,
unsigned
int
size
,
Counter
*
histogram
[
ActiveChannels
],
unsigned
int
levels
[
ActiveChannels
],
Level
*
level_values
[
ActiveChannels
],
cudaStream_t
stream
=
0
,
bool
debug_synchronous
=
false
)
{
return
detail
::
histogram_range_impl
<
Channels
,
ActiveChannels
,
Config
>
(
temporary_storage
,
storage_size
,
samples
,
size
,
1
,
0
,
histogram
,
levels
,
level_values
,
stream
,
debug_synchronous
);
}
/// \brief Computes histograms from a two-dimensional region of multi-channel samples using the specified bin
/// boundary levels.
///
/// \par
/// * The two-dimensional region of interest within \p samples can be specified using the \p columns,
/// \p rows and \p row_stride_bytes parameters.
/// * The row stride must be a whole multiple of the sample data type size,
/// i.e., <tt>(row_stride_bytes % sizeof(std::iterator_traits<SampleIterator>::value_type)) == 0</tt>.
/// * The input is a sequence of <em>pixel</em> structures, where each pixel comprises
/// a record of \p Channels consecutive data samples (e.g., \p Channels = 4 for <em>RGBA</em> samples).
/// * The first \p ActiveChannels channels of total \p Channels channels will be used for computing histograms
/// (e.g., \p ActiveChannels = 3 for computing histograms of only <em>RGB</em> from <em>RGBA</em> samples).
/// * For channel<sub><em>i</em></sub> the number of histogram bins is (\p levels[i] - 1).
/// * For channel<sub><em>i</em></sub> the range for bin<sub><em>j</em></sub> is
/// [<tt>level_values[i][j]</tt>, <tt>level_values[i][j+1]</tt>).
/// * Returns the required size of \p temporary_storage in \p storage_size
/// if \p temporary_storage is a null pointer.
///
/// \tparam Channels - number of channels interleaved in the input samples.
/// \tparam ActiveChannels - number of channels being used for computing histograms.
/// \tparam Config - [optional] configuration of the primitive. It can be \p histogram_config or
/// a custom class with the same members.
/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the
/// requirements of a C++ InputIterator concept. It can be a simple pointer type.
/// \tparam Counter - integer type for histogram bin counters.
/// \tparam Level - type of histogram boundaries (levels)
///
/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When
/// a null pointer is passed, the required allocation size (in bytes) is written to
/// \p storage_size and function returns without performing the reduction operation.
/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage.
/// \param [in] samples - iterator to the first element in the range of input samples.
/// \param [in] columns - number of elements in each row of the region.
/// \param [in] rows - number of rows of the region.
/// \param [in] row_stride_bytes - number of bytes between starts of consecutive rows of the region.
/// \param [out] histogram - pointers to the first element in the histogram range, one for each active channel.
/// \param [in] levels - number of boundaries (levels) for histogram bins in each active channel.
/// \param [in] level_values - pointer to the array of bin boundaries for each active channel.
/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream).
/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel
/// launch is forced in order to check for errors. Default value is \p false.
///
/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of
/// type \p cudaError_t.
///
/// \par Example
/// \parblock
/// In this example histograms for 3 channels (RGB) are computed on an array of 8-bit RGBA samples.
///
/// \code{.cpp}
/// #include <rocprim/rocprim.hpp>
///
/// // Prepare input and output (declare pointers, allocate device memory etc.)
/// unsigned int columns; // e.g., 4
/// unsigned int rows; // e.g., 2
/// size_t row_stride_bytes; // e.g., 5 * sizeof(unsigned char)
/// unsigned char * samples; // e.g., [(0, 0, 80, 0), (120, 0, 80, 0), (123, 0, 82, 0), (10, 1, 83, 0), (-, -, -, -),
/// // (51, 1, 8, 0), (52, 1, 8, 0), (53, 0, 81, 0), (54, 50, 81, 0), (-, -, -, -)]
/// int * histogram[3]; // 3 empty arrays
/// unsigned int levels[3]; // e.g., [4, 4, 3]
/// int * level_values[3]; // e.g., [[0, 50, 100, 200], [0, 20, 40, 60], [0, 10, 100]]
///
/// size_t temporary_storage_size_bytes;
/// void * temporary_storage_ptr = nullptr;
/// // Get required size of the temporary storage
/// rocprim::multi_histogram_range<4, 3>(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, columns, rows, row_stride_bytes,
/// histogram, levels, level_values
/// );
///
/// // allocate temporary storage
/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes);
///
/// // compute histograms
/// rocprim::multi_histogram_range<4, 3>(
/// temporary_storage_ptr, temporary_storage_size_bytes,
/// samples, columns, rows, row_stride_bytes,
/// histogram, levels, level_values
/// );
/// // histogram: [[2, 4, 2], [7, 0, 1], [2, 6]]
/// \endcode
/// \endparblock
template
<
unsigned
int
Channels
,
unsigned
int
ActiveChannels
,
class
Config
=
default_config
,
class
SampleIterator
,
class
Counter
,
class
Level
>
inline
cudaError_t
multi_histogram_range
(
void
*
temporary_storage
,
size_t
&
storage_size
,
SampleIterator
samples
,
unsigned
int
columns
,
unsigned
int
rows
,
size_t
row_stride_bytes
,
Counter
*
histogram
[
ActiveChannels
],
unsigned
int
levels
[
ActiveChannels
],
Level
*
level_values
[
ActiveChannels
],
cudaStream_t
stream
=
0
,
bool
debug_synchronous
=
false
)
{
return
detail
::
histogram_range_impl
<
Channels
,
ActiveChannels
,
Config
>
(
temporary_storage
,
storage_size
,
samples
,
columns
,
rows
,
row_stride_bytes
,
histogram
,
levels
,
level_values
,
stream
,
debug_synchronous
);
}
/// @}
// end of group devicemodule
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DEVICE_HISTOGRAM_HPP_
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment