Unverified Commit 7c24654c authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Fix incomplete object size (=4n + 3) support of amd_wave_read_first_lane() (#738)

* Fix wrong pointer type

* Rename type trait get_unsigned_int<> to get_carrier<>

* Add 3-bytes carrier type

* Add missing __device__ specifier

* Rename template non-type parameter

* Leave the rest byte uninitialized

* Avoid invoking (host) STL algorithms

* Remove unnecessary 'inline' specifier

* Extract common logic out as helper method

* Hide dummy member function

* Add missing __device__ specifier
parent 0ede66de
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "ck/utility/functional2.hpp" #include "ck/utility/functional2.hpp"
#include "ck/utility/math.hpp" #include "ck/utility/math.hpp"
#include <array>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <type_traits> #include <type_traits>
...@@ -14,29 +15,83 @@ ...@@ -14,29 +15,83 @@
namespace ck { namespace ck {
namespace detail { namespace detail {
template <unsigned Size> template <unsigned SizeInBytes>
struct get_unsigned_int; struct get_carrier;
template <> template <>
struct get_unsigned_int<1> struct get_carrier<1>
{ {
using type = uint8_t; using type = uint8_t;
}; };
template <> template <>
struct get_unsigned_int<2> struct get_carrier<2>
{ {
using type = uint16_t; using type = uint16_t;
}; };
template <> template <>
struct get_unsigned_int<4> struct get_carrier<3>
{
using type = class carrier
{
using value_type = uint32_t;
std::array<std::byte, 3> bytes;
static_assert(sizeof(bytes) <= sizeof(value_type));
// replacement of host std::copy_n()
template <typename InputIterator, typename Size, typename OutputIterator>
__device__ static OutputIterator copy_n(InputIterator from, Size size, OutputIterator to)
{
if(0 < size)
{
*to = *from;
++to;
for(Size count = 1; count < size; ++count)
{
*to = *++from;
++to;
}
}
return to;
}
// method to trigger template substitution failure
__device__ carrier(const carrier& other) noexcept
{
copy_n(other.bytes.begin(), bytes.size(), bytes.begin());
}
public:
__device__ carrier& operator=(value_type value) noexcept
{
copy_n(reinterpret_cast<const std::byte*>(&value), bytes.size(), bytes.begin());
return *this;
}
__device__ operator value_type() const noexcept
{
std::byte result[sizeof(value_type)];
copy_n(bytes.begin(), bytes.size(), result);
return *reinterpret_cast<const value_type*>(result);
}
};
};
static_assert(sizeof(get_carrier<3>::type) == 3);
template <>
struct get_carrier<4>
{ {
using type = uint32_t; using type = uint32_t;
}; };
template <unsigned Size> template <unsigned SizeInBytes>
using get_unsigned_int_t = typename get_unsigned_int<Size>::type; using get_carrier_t = typename get_carrier<SizeInBytes>::type;
} // namespace detail } // namespace detail
...@@ -61,7 +116,7 @@ __device__ auto amd_wave_read_first_lane(const Object& obj) ...@@ -61,7 +116,7 @@ __device__ auto amd_wave_read_first_lane(const Object& obj)
constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize; constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;
for(Size offset = 0; offset < CompleteSgprCopyBoundary; offset += SgprSize) for(Size offset = 0; offset < CompleteSgprCopyBoundary; offset += SgprSize)
{ {
using Sgpr = detail::get_unsigned_int_t<SgprSize>; using Sgpr = detail::get_carrier_t<SgprSize>;
*reinterpret_cast<Sgpr*>(to_obj + offset) = *reinterpret_cast<Sgpr*>(to_obj + offset) =
amd_wave_read_first_lane(*reinterpret_cast<const Sgpr*>(from_obj + offset)); amd_wave_read_first_lane(*reinterpret_cast<const Sgpr*>(from_obj + offset));
...@@ -69,9 +124,9 @@ __device__ auto amd_wave_read_first_lane(const Object& obj) ...@@ -69,9 +124,9 @@ __device__ auto amd_wave_read_first_lane(const Object& obj)
if constexpr(0 < RemainedSize) if constexpr(0 < RemainedSize)
{ {
using Carrier = detail::get_unsigned_int_t<RemainedSize>; using Carrier = detail::get_carrier_t<RemainedSize>;
*reinterpret_cast<Carrier>(to_obj + CompleteSgprCopyBoundary) = amd_wave_read_first_lane( *reinterpret_cast<Carrier*>(to_obj + CompleteSgprCopyBoundary) = amd_wave_read_first_lane(
*reinterpret_cast<const Carrier*>(from_obj + CompleteSgprCopyBoundary)); *reinterpret_cast<const Carrier*>(from_obj + CompleteSgprCopyBoundary));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment