Commit ad65dfe7 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 2a27d15c 15baccf2
...@@ -991,7 +991,8 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, ...@@ -991,7 +991,8 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
asm volatile("s_mov_b32 m0, %0; \n\t" asm volatile("s_mov_b32 m0, %0; \n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(global_offset_bytes), "v"(global_offset_bytes),
"s"(src_resource)); "s"(src_resource)
: "memory");
#else #else
// LDS pointer must be attributed with the LDS address space. // LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr = __attribute__((address_space(3))) uint32_t* lds_ptr =
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#pragma once
namespace ck {
// Sparse MFMA (smfmac) wrapper: accumulates a 16x16x32 tile, f16 inputs -> f32 accumulator.
// Primary template is declared only; the supported wave tile is specialized below.
template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_16x16x32f16;
// 16x16 wave-tile specialization.
template <>
struct intrin_smfmac_f32_16x16x32f16<16, 16>
{
// reg_a:   4 x f16 compressed A fragment; reg_b: 8 x f16 B fragment.
// reg_idx: sparsity-index operand forwarded to the builtin (selects which A
//          elements are present — presumably 2:4 structured sparsity; confirm).
// reg_c:   accumulator wrapper; read and updated in place as a float4_t.
template <class FloatC>
__device__ static void
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{
// The smfmac builtins exist only on gfx94x targets.
#if defined(__gfx94__)
// Trailing 0, 0 are immediate modifiers left at their defaults.
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
#else
// Other targets: no-op; consume the arguments to silence unused warnings.
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
}
};
// Sparse MFMA (smfmac) wrapper: accumulates a 16x16x32 tile, bf16 inputs -> f32 accumulator.
// Primary template is declared only; the supported wave tile is specialized below.
template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_16x16x32bf16;
// 16x16 wave-tile specialization.
template <>
struct intrin_smfmac_f32_16x16x32bf16<16, 16>
{
// reg_a:   4 x bf16 compressed A fragment; reg_b: 8 x bf16 B fragment.
// reg_idx: sparsity-index operand forwarded to the builtin (selects which A
//          elements are present — presumably 2:4 structured sparsity; confirm).
// reg_c:   accumulator wrapper; read and updated in place as a float4_t.
template <class FloatC>
__device__ static void
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{
// The smfmac builtins exist only on gfx94x targets.
#if defined(__gfx94__)
// Trailing 0, 0 are immediate modifiers left at their defaults.
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
#else
// Other targets: no-op; consume the arguments to silence unused warnings.
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
}
};
// Sparse MFMA (smfmac) wrapper: accumulates a 32x32x16 tile, f16 inputs -> f32 accumulator.
// Primary template is declared only; the supported wave tile is specialized below.
template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_32x32x16f16;
// 32x32 wave-tile specialization.
template <>
struct intrin_smfmac_f32_32x32x16f16<32, 32>
{
// reg_a:   4 x f16 compressed A fragment; reg_b: 8 x f16 B fragment.
// reg_idx: sparsity-index operand forwarded to the builtin (selects which A
//          elements are present — presumably 2:4 structured sparsity; confirm).
// reg_c:   accumulator wrapper; the larger 32x32 tile is read and updated in
//          place as a float16_t (16 packed floats).
template <class FloatC>
__device__ static void
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{
// The smfmac builtins exist only on gfx94x targets.
#if defined(__gfx94__)
// Trailing 0, 0 are immediate modifiers left at their defaults.
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
#else
// Other targets: no-op; consume the arguments to silence unused warnings.
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
}
};
// Sparse MFMA (smfmac) wrapper: accumulates a 32x32x16 tile, bf16 inputs -> f32 accumulator.
// Primary template is declared only; the supported wave tile is specialized below.
template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_32x32x16bf16;
// 32x32 wave-tile specialization.
template <>
struct intrin_smfmac_f32_32x32x16bf16<32, 32>
{
// reg_a:   4 x bf16 compressed A fragment; reg_b: 8 x bf16 B fragment.
// reg_idx: sparsity-index operand forwarded to the builtin (selects which A
//          elements are present — presumably 2:4 structured sparsity; confirm).
// reg_c:   accumulator wrapper; read and updated in place as a float16_t
//          (16 packed floats).
template <class FloatC>
__device__ static void
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{
// The smfmac builtins exist only on gfx94x targets.
#if defined(__gfx94__)
// Trailing 0, 0 are immediate modifiers left at their defaults.
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
#else
// Other targets: no-op; consume the arguments to silence unused warnings.
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
}
};
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -95,11 +95,33 @@ using get_carrier_t = typename get_carrier<SizeInBytes>::type; ...@@ -95,11 +95,33 @@ using get_carrier_t = typename get_carrier<SizeInBytes>::type;
} // namespace detail } // namespace detail
// Broadcast `value` from the first active lane of the wave to every lane,
// yielding a wave-uniform (scalar-register) value. uint32_t overload.
__device__ inline uint32_t amd_wave_read_first_lane(uint32_t value)
{
return __builtin_amdgcn_readfirstlane(value);
}
__device__ inline int32_t amd_wave_read_first_lane(int32_t value) __device__ inline int32_t amd_wave_read_first_lane(int32_t value)
{ {
return __builtin_amdgcn_readfirstlane(value); return __builtin_amdgcn_readfirstlane(value);
} }
// Broadcast `value` from the first active lane of the wave to every lane.
// int64_t overload: readfirstlane only operates on 32-bit scalars, so the
// value is split into its low and high dwords, each dword is broadcast via
// the 32-bit overload, and the result is reassembled. Shifting through
// uint64_t is fully defined C++, unlike the previous implementation that
// reinterpret_cast'ed through a raw std::byte buffer (strict-aliasing /
// object-lifetime UB).
__device__ inline int64_t amd_wave_read_first_lane(int64_t value)
{
const uint64_t bits = static_cast<uint64_t>(value);
// Broadcast each 32-bit half independently.
const uint32_t lo = amd_wave_read_first_lane(static_cast<uint32_t>(bits));
const uint32_t hi = amd_wave_read_first_lane(static_cast<uint32_t>(bits >> 32));
return static_cast<int64_t>((static_cast<uint64_t>(hi) << 32) | static_cast<uint64_t>(lo));
}
template < template <
typename Object, typename Object,
typename = std::enable_if_t<std::is_class_v<Object> && std::is_trivially_copyable_v<Object>>> typename = std::enable_if_t<std::is_class_v<Object> && std::is_trivially_copyable_v<Object>>>
......
...@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> ...@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
} }
}; };
// gfx12
/********************************WAVE32 MODE***********************************************/
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
// src: fp16, dst: fp32
// WMMA wrapper for gfx12 targets (gfx1200/gfx1201), wave32 mode:
// 16x16x16 tile, f16 inputs -> f32 accumulator.
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32_gfx12;
template <>
struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16>
{
// reg_a/reg_b: 8 x f16 operand fragments; reg_c: accumulator wrapper,
// read and updated in place as a float8_t.
template <class FloatC>
__device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
{
// NOTE: an inline-assembly path (disabled below) would be needed to
// eliminate duplicated data loads — the compiler will not remove them.
// amd_assembly_wmma_f32_16x16x16_f16_w32(
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
#if defined(__gfx12__)
reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
// Non-gfx12 targets: no-op; consume the arguments to silence unused warnings.
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
// src: bf16, dst: fp32
// WMMA wrapper for gfx12 targets (gfx1200/gfx1201), wave32 mode:
// 16x16x16 tile, bf16 inputs -> f32 accumulator.
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12;
template <>
struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12<16, 16>
{
// reg_a/reg_b: 8 x bf16 operand fragments; reg_c: accumulator wrapper,
// read and updated in place as a float8_t.
template <class FloatC>
__device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx12__)
reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
// Non-gfx12 targets: no-op; consume the arguments to silence unused warnings.
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
// src: iu8, dst: i32
// WMMA wrapper for gfx12 targets (gfx1200/gfx1201), wave32 mode:
// 16x16x16 tile, 8-bit integer inputs -> i32 accumulator.
// neg_a / neg_b: forwarded as the builtin's per-operand flags (per LLVM's iu8
//                WMMA builtin these select the signedness interpretation of
//                A / B — confirm against the builtin's documentation).
// clamp:        forwarded as the builtin's final immediate (presumably result
//                saturation — confirm).
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12;
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp>
{
// reg_a/reg_b: 8 x i8 fragments, bit_cast to int32x2_t because the builtin
// takes dword-packed operands. reg_c: accumulator, updated as an int32x8_t.
template <class FloatC>
__device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx12__)
reg_c.template AsType<int32x8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
neg_a,
bit_cast<int32x2_t>(reg_a),
neg_b,
bit_cast<int32x2_t>(reg_b),
reg_c.template AsType<int32x8_t>()[Number<0>{}],
clamp);
#else
// Non-gfx12 targets: no-op; consume the arguments to silence unused warnings.
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
} // namespace ck } // namespace ck
#endif #endif
...@@ -36,6 +36,8 @@ struct Array ...@@ -36,6 +36,8 @@ struct Array
return *this; return *this;
} }
// Const iterator support: pointer to the first element of the backing array.
__host__ __device__ constexpr const TData* begin() const { return &mData[0]; }
// Past-the-end pointer (one past mData[NSize - 1]); together with begin()
// this forms a valid [begin(), end()) range for range-for and algorithms.
__host__ __device__ constexpr const TData* end() const { return &mData[NSize]; }
}; };
// empty Array // empty Array
......
...@@ -203,7 +203,7 @@ struct vector_type<T, 1> ...@@ -203,7 +203,7 @@ struct vector_type<T, 1>
} }
}; };
int static err = 0; __device__ int static err = 0;
template <typename T> template <typename T>
struct vector_type<T, 2> struct vector_type<T, 2>
{ {
......
...@@ -10,12 +10,20 @@ namespace ck { ...@@ -10,12 +10,20 @@ namespace ck {
__device__ void block_sync_lds() __device__ void block_sync_lds()
{ {
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#ifdef __gfx12__
asm volatile("\
s_wait_dscnt 0x0 \n \
s_barrier_signal -1 \n \
s_barrier_wait -1 \
" ::);
#else
// asm volatile("\ // asm volatile("\
// s_waitcnt lgkmcnt(0) \n \ // s_waitcnt lgkmcnt(0) \n \
// s_barrier \ // s_barrier \
// " ::); // " ::);
__builtin_amdgcn_s_waitcnt(0xc07f); __builtin_amdgcn_s_waitcnt(0xc07f);
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
#endif
#else #else
__syncthreads(); __syncthreads();
#endif #endif
...@@ -23,11 +31,20 @@ __device__ void block_sync_lds() ...@@ -23,11 +31,20 @@ __device__ void block_sync_lds()
__device__ void block_sync_lds_direct_load() __device__ void block_sync_lds_direct_load()
{ {
#ifdef __gfx12__
asm volatile("\
s_wait_vmcnt 0x0 \n \
s_wait_dscnt 0x0 \n \
s_barrier_signal -1 \n \
s_barrier_wait -1 \
" ::);
#else
asm volatile("\ asm volatile("\
s_waitcnt vmcnt(0) \n \ s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \ s_waitcnt lgkmcnt(0) \n \
s_barrier \ s_barrier \
" ::); " ::);
#endif
} }
__device__ void s_nop() __device__ void s_nop()
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/null_type.hpp"
#include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/numeric/vector_type.hpp"
......
...@@ -26,7 +26,12 @@ struct __attribute__((packed)) buffer_resource ...@@ -26,7 +26,12 @@ struct __attribute__((packed)) buffer_resource
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff) CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff)
{ {
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
return __builtin_bit_cast(int32x4_t, res); int32x4_t r = __builtin_bit_cast(int32x4_t, res);
r.x = __builtin_amdgcn_readfirstlane(r.x);
r.y = __builtin_amdgcn_readfirstlane(r.y);
r.z = __builtin_amdgcn_readfirstlane(r.z);
r.w = __builtin_amdgcn_readfirstlane(r.w);
return r;
} }
namespace impl { namespace impl {
...@@ -49,233 +54,318 @@ template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payloa ...@@ -49,233 +54,318 @@ template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payloa
} // namespace impl } // namespace impl
// TODO: glc/slc/... // TODO: glc/slc/...
template <index_t bytes> template <index_t bytes, bool pre_nop = false>
struct buffer_load; struct buffer_load;
#pragma clang diagnostic push #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" #pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
// TODO: strict aliasing rule seems fail when reinterpret_cast between vector type // TODO: strict aliasing rule seems fail when reinterpret_cast between vector type
// (exp_vector_type(xxx)) // (exp_vector_type(xxx))
template <> template <bool pre_nop>
struct buffer_load<16> struct buffer_load<16, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0) index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 16); static_assert(sizeof(T) == 16);
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" if constexpr(pre_nop)
: "+v"(reinterpret_cast<mbuf_t&>(value)) asm volatile("s_nop 4\n"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
: "memory"); : "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load<8> struct buffer_load<8, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0) index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 8); static_assert(sizeof(T) == 8);
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" if constexpr(pre_nop)
: "+v"(reinterpret_cast<mbuf_t&>(value)) asm volatile("s_nop 4\n"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
: "memory"); : "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load<4> struct buffer_load<4, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0) index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" if constexpr(pre_nop)
: "+v"(reinterpret_cast<mbuf_t&>(value)) asm volatile("s_nop 4\n"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) "buffer_load_dword %0, %1, %2, 0 offen offset:%3"
: "memory"); : "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load<2> struct buffer_load<2, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0) index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" if constexpr(pre_nop)
: "+v"(reinterpret_cast<mbuf_t&>(value)) asm volatile("s_nop 4\n"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) "buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
: "memory"); : "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load<1> struct buffer_load<1, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0) index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" if constexpr(pre_nop)
: "+v"(reinterpret_cast<mbuf_t&>(value)) asm volatile("s_nop 4\n"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
: "memory"); : "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
} }
}; };
template <index_t bytes> template <index_t bytes, bool pre_nop = false>
struct buffer_load_if; struct buffer_load_if;
template <> template <bool pre_nop>
struct buffer_load_if<16> struct buffer_load_if<16, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 0) index_t flag = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 16); static_assert(sizeof(T) == 16);
auto saved_exec = __builtin_amdgcn_read_exec(); auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
static_assert(sizeof(mbuf_t) == sizeof(T)); static_assert(sizeof(mbuf_t) == sizeof(T));
asm volatile( if constexpr(pre_nop)
"v_cmpx_le_u32 exec, 1, %5\n" asm volatile("s_nop 4\n"
"buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n" "v_cmpx_le_u32 exec, 1, %4\n"
"s_mov_b64 exec %6" "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
: "+v"(reinterpret_cast<mbuf_t&>(value)) "s_mov_b64 exec %5"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) : "+v"(reinterpret_cast<mbuf_t&>(value))
: "memory"); : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load_if<8> struct buffer_load_if<8, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 0) index_t flag = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 8); static_assert(sizeof(T) == 8);
auto saved_exec = __builtin_amdgcn_read_exec(); auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
asm volatile( if constexpr(pre_nop)
"v_cmpx_le_u32 exec, 1, %5\n" asm volatile("s_nop 4\n"
"buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n" "v_cmpx_le_u32 exec, 1, %4\n"
"s_mov_b64 exec %6" "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
: "+v"(reinterpret_cast<mbuf_t&>(value)) "s_mov_b64 exec %5"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) : "+v"(reinterpret_cast<mbuf_t&>(value))
: "memory"); : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load_if<4> struct buffer_load_if<4, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 0) index_t flag = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec(); auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
asm volatile( if constexpr(pre_nop)
"v_cmpx_le_u32 exec, 1, %5\n" asm volatile("s_nop 4\n"
"buffer_load_dword %0, %1, %2, %3 offen offset:%4\n" "v_cmpx_le_u32 exec, 1, %4\n"
"s_mov_b64 exec %6" "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
: "+v"(reinterpret_cast<mbuf_t&>(value)) "s_mov_b64 exec %5"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) : "+v"(reinterpret_cast<mbuf_t&>(value))
: "memory"); : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load_if<2> struct buffer_load_if<2, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 0) index_t flag = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec(); auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
asm volatile( if constexpr(pre_nop)
"v_cmpx_le_u32 exec, 1, %5\n" asm volatile("s_nop 4\n"
"buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n" "v_cmpx_le_u32 exec, 1, %4\n"
"s_mov_b64 exec %6" "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
: "+v"(reinterpret_cast<mbuf_t&>(value)) "s_mov_b64 exec %5"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) : "+v"(reinterpret_cast<mbuf_t&>(value))
: "memory"); : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load_if<1> struct buffer_load_if<1, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 0) index_t flag = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec(); auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
asm volatile( if constexpr(pre_nop)
"v_cmpx_le_u32 exec, 1, %5\n" asm volatile("s_nop 4\n"
"buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n" "v_cmpx_le_u32 exec, 1, %4\n"
"s_mov_b64 exec %6" "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
: "+v"(reinterpret_cast<mbuf_t&>(value)) "s_mov_b64 exec %5"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) : "+v"(reinterpret_cast<mbuf_t&>(value))
: "memory"); : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
} }
}; };
#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast" #pragma clang diagnostic pop // "-Wundefined-reinterpret-cast"
...@@ -289,17 +379,16 @@ struct buffer_store<16> ...@@ -289,17 +379,16 @@ struct buffer_store<16>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1) index_t /*flag*/ = 1)
{ {
static_assert(sizeof(T) == 16); static_assert(sizeof(T) == 16);
using mbuf_t = fp32x4_t; using mbuf_t = fp32x4_t;
asm volatile( asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3"
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4" :
: : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory");
: "memory");
} }
}; };
...@@ -310,17 +399,16 @@ struct buffer_store<8> ...@@ -310,17 +399,16 @@ struct buffer_store<8>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1) index_t /*flag*/ = 1)
{ {
static_assert(sizeof(T) == 8); static_assert(sizeof(T) == 8);
using mbuf_t = fp32x2_t; using mbuf_t = fp32x2_t;
asm volatile( asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3"
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4" :
: : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory");
: "memory");
} }
}; };
...@@ -331,17 +419,16 @@ struct buffer_store<4> ...@@ -331,17 +419,16 @@ struct buffer_store<4>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1) index_t /*flag*/ = 1)
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
using mbuf_t = float; using mbuf_t = float;
asm volatile( asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3"
"buffer_store_dword %0, %1, %2, %3 offen offset:%4" :
: : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory");
: "memory");
} }
}; };
...@@ -352,17 +439,16 @@ struct buffer_store<2> ...@@ -352,17 +439,16 @@ struct buffer_store<2>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1) index_t /*flag*/ = 1)
{ {
static_assert(sizeof(T) == 2); static_assert(sizeof(T) == 2);
using mbuf_t = short; using mbuf_t = short;
asm volatile( asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3"
"buffer_store_short %0, %1, %2, %3 offen offset:%4" :
: : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory");
: "memory");
} }
}; };
...@@ -373,17 +459,16 @@ struct buffer_store<1> ...@@ -373,17 +459,16 @@ struct buffer_store<1>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1) index_t /*flag*/ = 1)
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
using mbuf_t = float; using mbuf_t = float;
asm volatile( asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3"
"buffer_store_byte %0, %1, %2, %3 offen offset:%4" :
: : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory");
: "memory");
} }
}; };
...@@ -397,21 +482,20 @@ struct buffer_store_if<16> ...@@ -397,21 +482,20 @@ struct buffer_store_if<16>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 1) index_t flag = 1)
{ {
static_assert(sizeof(T) == 16); static_assert(sizeof(T) == 16);
auto save_exec = __builtin_amdgcn_read_exec(); auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x4_t; using mbuf_t = fp32x4_t;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n" asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n" "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %6" "s_mov_b64 exec %5"
: :
: "v"(bit_cast<mbuf_t>(value)), : "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset), "v"(v_offset),
"s"(res), "s"(res),
"s"(s_offset),
"n"(i_offset), "n"(i_offset),
"v"(flag), "v"(flag),
"s"(save_exec) "s"(save_exec)
...@@ -426,7 +510,7 @@ struct buffer_store_if<8> ...@@ -426,7 +510,7 @@ struct buffer_store_if<8>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 1) index_t flag = 1)
{ {
...@@ -434,14 +518,13 @@ struct buffer_store_if<8> ...@@ -434,14 +518,13 @@ struct buffer_store_if<8>
auto save_exec = __builtin_amdgcn_read_exec(); auto save_exec = __builtin_amdgcn_read_exec();
// TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch // TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch
using mbuf_t = ext_vector_t<typename T::value_type, T::size()>; using mbuf_t = ext_vector_t<typename T::value_type, T::size()>;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n" asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n" "buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %6" "s_mov_b64 exec %5"
: :
: "v"(bit_cast<mbuf_t>(value)), : "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset), "v"(v_offset),
"s"(res), "s"(res),
"s"(s_offset),
"n"(i_offset), "n"(i_offset),
"v"(flag), "v"(flag),
"s"(save_exec) "s"(save_exec)
...@@ -456,21 +539,20 @@ struct buffer_store_if<4> ...@@ -456,21 +539,20 @@ struct buffer_store_if<4>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 1) index_t flag = 1)
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec(); auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float; using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n" asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_dword %0, %1, %2, %3 offen offset:%4\n" "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %6" "s_mov_b64 exec %5"
: :
: "v"(bit_cast<mbuf_t>(value)), : "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset), "v"(v_offset),
"s"(res), "s"(res),
"s"(s_offset),
"n"(i_offset), "n"(i_offset),
"v"(flag), "v"(flag),
"s"(save_exec) "s"(save_exec)
...@@ -485,21 +567,20 @@ struct buffer_store_if<2> ...@@ -485,21 +567,20 @@ struct buffer_store_if<2>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 1) index_t flag = 1)
{ {
static_assert(sizeof(T) == 2); static_assert(sizeof(T) == 2);
auto save_exec = __builtin_amdgcn_read_exec(); auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = short; using mbuf_t = short;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n" asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_short %0, %1, %2, %3 offen offset:%4\n" "buffer_store_short %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %6" "s_mov_b64 exec %5"
: :
: "v"(bit_cast<mbuf_t>(value)), : "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset), "v"(v_offset),
"s"(res), "s"(res),
"s"(s_offset),
"n"(i_offset), "n"(i_offset),
"v"(flag), "v"(flag),
"s"(save_exec) "s"(save_exec)
...@@ -514,21 +595,20 @@ struct buffer_store_if<1> ...@@ -514,21 +595,20 @@ struct buffer_store_if<1>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 1) index_t flag = 1)
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec(); auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float; using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n" asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_byte %0, %1, %2, %3 offen offset:%4\n" "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %6" "s_mov_b64 exec %5"
: :
: "v"(bit_cast<mbuf_t>(value)), : "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset), "v"(v_offset),
"s"(res), "s"(res),
"s"(s_offset),
"n"(i_offset), "n"(i_offset),
"v"(flag), "v"(flag),
"s"(save_exec) "s"(save_exec)
...@@ -552,8 +632,9 @@ namespace impl{ ...@@ -552,8 +632,9 @@ namespace impl{
template<index_t N> template<index_t N>
CK_TILE_DEVICE void insert_dummy_dep_per_dword(array<float, N>& b) CK_TILE_DEVICE void insert_dummy_dep_per_dword(array<float, N>& b)
{ {
static_for<0, b.size(), 1>{}([&](auto i){ constexpr auto kSize = remove_cvref_t<decltype(b)>::size();
asm volatile(" " : : "v"(b.get(i)) : "memory"); static_for<0, kSize, 1>{}([&](auto i){
asm volatile(" " : : "v"(b.get(number<i>{})) : "memory");
}); });
} }
#if 1 #if 1
...@@ -895,17 +976,26 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, ...@@ -895,17 +976,26 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int soffset, // dst_wave_addr_offset int soffset, // dst_wave_addr_offset
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
CK_TILE_DEVICE void async_buffer_load_dword(void* smem, template <bool pre_nop = false>
int32x4_t rsrc, CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
index_t voffset, int32x4_t rsrc,
index_t soffset, index_t voffset,
index_t ioffset /*max 0xFFF*/, index_t /*soffset*/,
index_t /*flag*/ = 0) index_t ioffset /*max 0xFFF*/,
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{ {
asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds" if constexpr(pre_nop)
: "=r"(smem) /*dummy dependency for smem*/ asm volatile("s_nop 4\n"
: "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset) "buffer_load_dword %1, %2, 0 offen offset:%3 lds"
: "memory"); : "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "n"(ioffset)
: "memory");
else
asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "n"(ioffset)
: "memory");
} }
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
...@@ -1217,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe ...@@ -1217,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
template <typename T, template <typename T,
index_t N, index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true> bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst, CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
int32x4_t src_wave_buffer_resource, int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset, index_t src_thread_addr_offset,
index_t src_wave_addr_offset, index_t src_wave_addr_offset,
index_t flag = 0) index_t flag = 0,
bool_constant<pre_nop> = {})
{ {
constexpr index_t bytes = sizeof(T) * N; constexpr index_t bytes = sizeof(T) * N;
static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
...@@ -1231,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst, ...@@ -1231,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
using type = thread_buffer<T, N>; using type = thread_buffer<T, N>;
if constexpr(oob_conditional_check) if constexpr(oob_conditional_check)
{ {
buffer_load_if<sizeof(type)>{}( buffer_load_if<sizeof(type), pre_nop>{}(dst,
dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
0,
flag,
bool_constant<pre_nop>{});
} }
else else
{ {
buffer_load<sizeof(type)>{}( buffer_load<sizeof(type), pre_nop>{}(dst,
dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
0,
flag,
bool_constant<pre_nop>{});
} }
} }
template <typename T, template <typename T,
index_t N, index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default> amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
int32x4_t src_wave_buffer_resource, int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset, index_t src_thread_addr_offset,
index_t src_wave_addr_offset, index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0) index_t src_immediate_addr_offset = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
async_buffer_load_dword(smem, async_buffer_load_dword_v(smem,
src_wave_buffer_resource, src_wave_buffer_resource,
src_thread_addr_offset, src_thread_addr_offset,
src_wave_addr_offset, src_wave_addr_offset,
src_immediate_addr_offset); src_immediate_addr_offset,
0,
bool_constant<pre_nop>{});
} }
template <index_t N, template <index_t N,
...@@ -1903,20 +2009,50 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, ...@@ -1903,20 +2009,50 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
template <typename T, template <typename T,
index_t N, index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true> bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst, CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const T* p_src_wave, const T* p_src_wave,
index_t src_thread_element_offset, index_t src_thread_element_offset,
index_t src_element_space_size, index_t src_element_space_size,
index_t is_valid_element = 0) index_t is_valid_element = 0,
bool_constant<pre_nop> = {})
{ {
const int32x4_t src_wave_buffer_resource = const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check>( amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element); dst,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
is_valid_element,
bool_constant<pre_nop>{});
}
// This version support buffer resource as input arg
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
index_t is_valid_element = 0,
bool_constant<pre_nop> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
is_valid_element,
bool_constant<pre_nop>{});
} }
// unfortunately async copy can not make sure invalid data is zero inside LDS // unfortunately async copy can not make sure invalid data is zero inside LDS
...@@ -1925,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst, ...@@ -1925,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
// buffer_load OOB still working. // buffer_load OOB still working.
template <typename T, template <typename T,
index_t N, index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default> amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, bool pre_nop = false>
const T* p_src_wave, CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
index_t src_thread_element_offset, const T* p_src_wave,
index_t src_element_space_size) index_t src_thread_element_offset,
index_t src_element_space_size,
bool_constant<pre_nop> = {})
{ {
const int32x4_t src_wave_buffer_resource = const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
...@@ -1937,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, ...@@ -1937,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>( amd_async_buffer_load_impl<T, N, coherence>(
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0); smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
}
// This version support buffer resource as input arg
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
bool_constant<pre_nop> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>(
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
} }
// buffer_store requires: // buffer_store requires:
...@@ -2103,7 +2257,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, ...@@ -2103,7 +2257,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
asm volatile("s_mov_b32 m0, %0; \n\t" asm volatile("s_mov_b32 m0, %0; \n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(global_offset_bytes), "v"(global_offset_bytes),
"s"(src_resource)); "s"(src_resource)
: "memory");
#else #else
// LDS pointer must be attributed with the LDS address space. // LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr = __attribute__((address_space(3))) uint32_t* lds_ptr =
......
...@@ -61,10 +61,13 @@ CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; } ...@@ -61,10 +61,13 @@ CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
CK_TILE_DEVICE void block_sync_lds() CK_TILE_DEVICE void block_sync_lds()
{ {
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM #if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
asm volatile("\ // asm volatile("\
s_waitcnt lgkmcnt(0) \n \ // s_waitcnt lgkmcnt(0) \n \
s_barrier \ // s_barrier \
" ::); // " ::);
__builtin_amdgcn_s_waitcnt(0xc07f);
__builtin_amdgcn_s_barrier();
#else #else
__syncthreads(); __syncthreads();
#endif #endif
...@@ -79,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load() ...@@ -79,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load()
" ::); " ::);
} }
CK_TILE_DEVICE void s_nop() CK_TILE_DEVICE void s_nop(index_t cnt = 0)
{ {
#if 1 #if 1
asm volatile("\ asm volatile("s_nop %0" : : "n"(cnt) :);
s_nop 0 \n \
" ::);
#else #else
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(cnt);
#endif #endif
} }
......
...@@ -17,7 +17,11 @@ ...@@ -17,7 +17,11 @@
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__ #define __gfx11__
#endif #endif
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
#include "hip/hip_version.h"
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
...@@ -144,6 +148,14 @@ ...@@ -144,6 +148,14 @@
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 #define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif #endif
#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
#else
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG #ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0 #define CK_TILE_DEBUG_LOG 0
#endif #endif
...@@ -155,7 +167,7 @@ ...@@ -155,7 +167,7 @@
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__) // for GPU code #elif defined(__gfx103__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx11__) // for GPU code #elif defined(__gfx11__) || defined(__gfx12__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif #endif
...@@ -167,6 +179,10 @@ ...@@ -167,6 +179,10 @@
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0 #define CK_TILE_USE_SUBDWORD_TILE_CAST 0
#endif #endif
#ifndef CK_TILE_USE_PK_FP16_TILE_CAST
#define CK_TILE_USE_PK_FP16_TILE_CAST 0
#endif
// TODO: better solve this inside compiler // TODO: better solve this inside compiler
#ifndef CK_TILE_FMHA_FWD_FAST_EXP2 #ifndef CK_TILE_FMHA_FWD_FAST_EXP2
#define CK_TILE_FMHA_FWD_FAST_EXP2 0 #define CK_TILE_FMHA_FWD_FAST_EXP2 0
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
namespace ck_tile {
struct null_type
{
};
} // namespace ck_tile
...@@ -69,6 +69,8 @@ struct buffer_view<address_space_enum::generic, ...@@ -69,6 +69,8 @@ struct buffer_view<address_space_enum::generic,
{ {
} }
CK_TILE_HOST_DEVICE void init_raw() {}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space() CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{ {
return address_space_enum::generic; return address_space_enum::generic;
...@@ -224,25 +226,36 @@ struct buffer_view<address_space_enum::global, ...@@ -224,25 +226,36 @@ struct buffer_view<address_space_enum::global,
T* p_data_ = nullptr; T* p_data_ = nullptr;
BufferSizeType buffer_size_; BufferSizeType buffer_size_;
int32x4_t cached_buf_res_;
remove_cvref_t<T> invalid_element_value_ = T{0}; remove_cvref_t<T> invalid_element_value_ = T{0};
CK_TILE_HOST_DEVICE constexpr buffer_view() CK_TILE_HOST_DEVICE constexpr buffer_view()
: p_data_{}, buffer_size_{}, invalid_element_value_{} : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
{ {
} }
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0} : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0}
{ {
} }
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
BufferSizeType buffer_size, BufferSizeType buffer_size,
T invalid_element_value) T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} : p_data_{p_data},
buffer_size_{buffer_size},
cached_buf_res_{0},
invalid_element_value_{invalid_element_value}
{ {
} }
// this is non constexpr intentially (will call some intrinsic internally)
// Must call for buffers that need *_raw load/store
CK_TILE_HOST_DEVICE void init_raw()
{
cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space() CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{ {
return address_space_enum::global; return address_space_enum::global;
...@@ -333,12 +346,15 @@ struct buffer_view<address_space_enum::global, ...@@ -333,12 +346,15 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X // i is offset of T, not X. i should be aligned to X
template <typename X, template <typename X,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if< typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value, typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false> bool>::type = false>
CK_TILE_DEVICE constexpr auto CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const index_t i,
bool is_valid_element,
bool_constant<pre_nop> = {}) const
{ {
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size; constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
...@@ -349,18 +365,21 @@ struct buffer_view<address_space_enum::global, ...@@ -349,18 +365,21 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>( amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
dst, p_data_, i, buffer_size_, is_valid_element); dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
} }
// i is offset of T, not X. i should be aligned to X // i is offset of T, not X. i should be aligned to X
template <typename X, template <typename X,
bool pre_nop = false,
typename std::enable_if< typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value, typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false> bool>::type = false>
CK_TILE_DEVICE constexpr auto CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t<T>* smem,
async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const index_t i,
bool /*is_valid_element*/,
bool_constant<pre_nop> = {}) const
{ {
// X is vector of T // X is vector of T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size; constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
...@@ -371,8 +390,8 @@ struct buffer_view<address_space_enum::global, ...@@ -371,8 +390,8 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>( amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
smem, p_data_, i, buffer_size_); smem, cached_buf_res_, i, bool_constant<pre_nop>{});
} }
// i is offset of T, not X. i should be aligned to X // i is offset of T, not X. i should be aligned to X
...@@ -627,6 +646,8 @@ struct buffer_view<address_space_enum::lds, ...@@ -627,6 +646,8 @@ struct buffer_view<address_space_enum::lds,
{ {
} }
CK_TILE_HOST_DEVICE void init_raw() {}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space() CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{ {
return address_space_enum::lds; return address_space_enum::lds;
...@@ -909,6 +930,8 @@ struct buffer_view<address_space_enum::vgpr, ...@@ -909,6 +930,8 @@ struct buffer_view<address_space_enum::vgpr,
{ {
} }
CK_TILE_HOST_DEVICE void init_raw() {}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space() CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{ {
return address_space_enum::vgpr; return address_space_enum::vgpr;
......
...@@ -36,30 +36,37 @@ template <typename T, ...@@ -36,30 +36,37 @@ template <typename T,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
bool oob_conditional_check = true> bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile, CK_TILE_DEVICE auto load_tile_raw(T& tile,
const tile_window_with_static_distribution<BottomTensorView_, const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {}) bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{ {
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}); tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
} }
template <typename LdsTileWindow_, template <typename LdsTileWindow_,
typename BottomTensorView_, typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord> index_t NumCoord,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile, async_load_tile_raw(LdsTileWindow_&& lds_tile,
const tile_window_with_static_distribution<BottomTensorView_, const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window) NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{ {
return tile_window.async_load(lds_tile); return tile_window.async_load_raw(
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
} }
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0) CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
......
...@@ -35,6 +35,8 @@ struct null_tile_window ...@@ -35,6 +35,8 @@ struct null_tile_window
CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; } CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }
CK_TILE_DEVICE void init_raw() {}
WindowLengths window_lengths_; WindowLengths window_lengths_;
}; };
......
...@@ -36,6 +36,8 @@ struct tensor_view ...@@ -36,6 +36,8 @@ struct tensor_view
{ {
} }
CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); }
CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; } CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension() CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
...@@ -85,30 +87,34 @@ struct tensor_view ...@@ -85,30 +87,34 @@ struct tensor_view
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X, template <typename X,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if< typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>, typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false> bool>::type = false>
CK_TILE_HOST_DEVICE void CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
get_vectorized_elements_raw(remove_cvref_t<X>& dst, const TensorCoord& coord,
const TensorCoord& coord, bool_constant<oob_conditional_check> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<pre_nop> = {}) const
{ {
return buf_.template get_raw<X, oob_conditional_check>( return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
dst, dst,
coord.get_offset(), coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
} }
template <typename X, template <typename X,
bool pre_nop = false,
typename std::enable_if< typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>, typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false> bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem, CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw(
const TensorCoord& coord) const remove_cvref_t<DataType>* smem, const TensorCoord& coord, bool_constant<pre_nop> = {}) const
{ {
return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/); return buf_.template async_get_raw<X>(
smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
} }
// X is vector of DataType. // X is vector of DataType.
......
...@@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&) ...@@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
// TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with // TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
// sub-dword tensor... // sub-dword tensor...
template <typename DstrTensors, index_t v> template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>) CK_TILE_DEVICE void
set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
{ {
constexpr index_t tensor_bytes = using elem_type = typename DstrTensors::DataType;
DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType); constexpr index_t elem_size = sizeof(elem_type);
if constexpr(v == 0 && tensor_bytes % 4 == 0)
constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;
// # bytes per write = 4
if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
{ {
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
auto& buffer = dstr_tensor.get_thread_buffer();
static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
if constexpr(elem_size == 1)
{
// # elements per write = 4
constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};
buffer[i_write * 4 + 0] = values.x;
buffer[i_write * 4 + 1] = values.y;
buffer[i_write * 4 + 2] = values.z;
buffer[i_write * 4 + 3] = values.w;
}
else if constexpr(elem_size == 2)
{
// # elements per write = 2
constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};
buffer[i_write * 2 + 0] = values.x;
buffer[i_write * 2 + 1] = values.y;
}
else if constexpr(elem_size == 4)
{
// # elements per write = 1
constexpr elem_type value = 0;
buffer[i_write] = value;
}
else
{
static_assert(false, "type not supported");
}
});
#else
using dvec_t = array<index_t, tensor_bytes / 4>; using dvec_t = array<index_t, tensor_bytes / 4>;
auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer()); auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
for(auto i = 0; i < tensor.size(); i++) for(auto i = 0; i < tensor.size(); i++)
tensor.get(i) = v; tensor.get(i) = v;
#endif
} }
else else
{ {
tile_elementwise_inout( tile_elementwise_inout([](auto& x) { x = type_convert<elem_type, index_t>(v); },
[](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); }, dstr_tensor);
dstr_tensor);
} }
} }
...@@ -110,7 +150,7 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor) ...@@ -110,7 +150,7 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
namespace impl { namespace impl {
// TODO: this is ugly // TODO: this is ugly
template <typename OutDataType, typename InTensor> template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// This API is designed to use the _pk_ serious of function // This API is designed to use the _pk_ serious of function
...@@ -156,6 +196,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) ...@@ -156,6 +196,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
#endif #endif
} }
// Convert a distributed tensor of fp32 elements to fp16 using the packed
// v_cvt_pkrtz instruction (two fp32 -> two fp16 per call, round-toward-zero).
// OutDataType is expected to be fp16_t; InTensor::DataType is expected to be float.
// Falls back to a plain element-wise type_convert on targets without the intrinsic.
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
// This API is designed to use the _pk_ series of functions
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
// packed conversion consumes elements in pairs, so the per-thread buffer
// must hold an even element count
static_assert(thread_buffer_size % 2 == 0);
constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
// TODO: this is rtz cvt (round toward zero, not round-to-nearest-even),
// need be very careful — results may differ from plain type_convert
for(index_t i = 0; i < thread_buffer_size_pk; i++)
{
// convert adjacent fp32 pair [2i, 2i+1] into a packed pair of fp16
auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
in_dstr_tensors.get_thread_buffer()[2 * i + 1]);
out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
}
return out_dstr_tensor;
#else
// fallback: generic per-element conversion on targets without pk-cvt support
return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
in_dstr_tensors);
#endif
}
#if CK_TILE_USE_SUBDWORD_TILE_CAST #if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assume either src or dst (or both) date type is under 1 dword // this function assume either src or dst (or both) date type is under 1 dword
// we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy) // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
...@@ -229,8 +300,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor) ...@@ -229,8 +300,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
float> && float> &&
(SrcTensor::get_thread_buffer_size() % 4 == 0)) (SrcTensor::get_thread_buffer_size() % 4 == 0))
{ {
return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor); return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
} }
#if CK_TILE_USE_PK_FP16_TILE_CAST
else if constexpr(std::is_same_v<DstType, fp16_t> &&
std::is_same_v<typename SrcTensor::DataType, float> &&
(SrcTensor::get_thread_buffer_size() % 2 == 0))
{
return impl::cast_tile_pk_fp16_fp32<DstType, SrcTensor>(src_tensor);
}
#endif
#if CK_TILE_USE_SUBDWORD_TILE_CAST #if CK_TILE_USE_SUBDWORD_TILE_CAST
else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4) else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
{ {
......
...@@ -344,9 +344,10 @@ struct tile_window_with_static_distribution ...@@ -344,9 +344,10 @@ struct tile_window_with_static_distribution
return dst_tensor; return dst_tensor;
} }
template <typename DstTile, bool oob_conditional_check = true> template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{ {
using Traits = load_store_traits; using Traits = load_store_traits;
...@@ -373,7 +374,13 @@ struct tile_window_with_static_distribution ...@@ -373,7 +374,13 @@ struct tile_window_with_static_distribution
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
// data index [y0, y1, ...] // data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
...@@ -384,7 +391,8 @@ struct tile_window_with_static_distribution ...@@ -384,7 +391,8 @@ struct tile_window_with_static_distribution
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>( get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(), dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
bottom_tensor_thread_coord, bottom_tensor_thread_coord,
bool_constant<oob_conditional_check>{}); bool_constant<oob_conditional_check>{},
pre_nop_);
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...@@ -399,12 +407,17 @@ struct tile_window_with_static_distribution ...@@ -399,12 +407,17 @@ struct tile_window_with_static_distribution
} }
}); });
}); });
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile("; this inline asm is workaround to prevent compiler from using too much "
"scratch memory" ::);
#endif
} }
// TODO: currently async load only implemented in inline asm // TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true> template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{ {
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>; using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
// using LdsTensorView = typename LdsTileWindow::BottomTensorView; // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
...@@ -449,11 +462,17 @@ struct tile_window_with_static_distribution ...@@ -449,11 +462,17 @@ struct tile_window_with_static_distribution
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
// read from bottom tensor // read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>( get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord); smem, bottom_tensor_thread_coord, pre_nop_);
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...@@ -668,6 +687,67 @@ struct tile_window_with_static_distribution ...@@ -668,6 +687,67 @@ struct tile_window_with_static_distribution
}); });
} }
// Move the window to a new origin in the bottom tensor and rebuild the
// pre-computed per-coordinate (window adaptor coord, bottom tensor coord)
// bundles so subsequent load/store() calls start from the new position.
// NOTE(review): mirrors the constructor's coordinate setup — keep the two in sync.
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
#if 0 // debug
// TODO: this use more register for FA, but less register for GEMM
// need investigation
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
// TODO: this use less register for FA, but more register for GEMM
// need investigation
// derive this thread's top index from the tile distribution, with Y dims zeroed
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
#endif
// this thread's starting element index in the bottom tensor
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using Traits = load_store_traits;
using SFC_Ys = typename Traits::SFC_Ys;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
// step both coordinates forward to the first access of this iCoord chunk
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
}
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
// this is the bottom tensor view // this is the bottom tensor view
// [x0', x1', ...] ==> [offset] // [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_; BottomTensorView bottom_tensor_view_;
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "ck_tile/host/reference/reference_batched_softmax.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/stream_config.hpp"
......
...@@ -56,8 +56,9 @@ check_err(const Range& out, ...@@ -56,8 +56,9 @@ check_err(const Range& out,
} }
const auto is_infinity_error = [=](auto o, auto r) { const auto is_infinity_error = [=](auto o, auto r) {
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); const bool both_infinite_and_same =
std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
}; };
...@@ -114,8 +115,9 @@ check_err(const Range& out, ...@@ -114,8 +115,9 @@ check_err(const Range& out,
} }
const auto is_infinity_error = [=](auto o, auto r) { const auto is_infinity_error = [=](auto o, auto r) {
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); const bool both_infinite_and_same =
std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
}; };
...@@ -173,8 +175,9 @@ check_err(const Range& out, ...@@ -173,8 +175,9 @@ check_err(const Range& out,
} }
const auto is_infinity_error = [=](auto o, auto r) { const auto is_infinity_error = [=](auto o, auto r) {
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); const bool both_infinite_and_same =
std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
}; };
...@@ -285,8 +288,9 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val ...@@ -285,8 +288,9 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
} }
const auto is_infinity_error = [=](auto o, auto r) { const auto is_infinity_error = [=](auto o, auto r) {
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); const bool both_infinite_and_same =
std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
}; };
...@@ -357,8 +361,9 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val ...@@ -357,8 +361,9 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
} }
const auto is_infinity_error = [=](auto o, auto r) { const auto is_infinity_error = [=](auto o, auto r) {
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); const bool both_infinite_and_same =
std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment