"...composable_kernel.git" did not exist on "adf2e4b13c20e9b6b2c1cd0bf4775df3debce6ce"
Commit 7a08aec6 authored by Chao Liu

move buffer load and store from threadwise copy into DynamicBuffer

parent 510b3a21
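
The gist of the change, sketched below as a minimal host-only C++ stand-in (SimpleDynamicBuffer, make_simple_dynamic_buffer and the main() driver are hypothetical names, not the CK types): a dynamic buffer now carries its element space size alongside the pointer, and Get/Set take the offset-validity flag, so callers no longer branch on address space or call amd_buffer_load_v2/store_v2 themselves.

    #include <cstddef>
    #include <cstdio>

    enum class AddressSpace { Generic, Global, Lds, Vgpr };

    // pointer + element space size; Get/Set are predicated on the validity flag
    template <AddressSpace BufferAddressSpace, typename T, typename ElementSpaceSize>
    struct SimpleDynamicBuffer
    {
        T* p_data_;
        ElementSpaceSize element_space_size_;

        T Get(std::size_t i, bool is_valid_offset) const
        {
            // invalid offsets read back as zero, like an out-of-range hardware buffer_load
            return is_valid_offset ? p_data_[i] : T{0};
        }

        void Set(std::size_t i, bool is_valid_offset, const T& x)
        {
            // invalid offsets are dropped, like an out-of-range hardware buffer_store
            if(is_valid_offset)
                p_data_[i] = x;
        }
    };

    template <AddressSpace BufferAddressSpace = AddressSpace::Generic,
              typename T,
              typename ElementSpaceSize>
    auto make_simple_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
    {
        return SimpleDynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size};
    }

    int main()
    {
        float data[8] = {};
        auto buf = make_simple_dynamic_buffer<AddressSpace::Global>(data, std::size_t{8});
        buf.Set(3, /*is_valid_offset=*/true, 2.5f);
        std::printf("%f %f\n", buf.Get(3, true), buf.Get(7, /*is_valid_offset=*/false));
        return 0;
    }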
@@ -185,9 +185,12 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};
 
-        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(p_a_global);
-        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(p_b_global);
-        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(p_c_global);
+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_a_global, a_k_m_global_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_b_global, b_k_n_global_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());
 
         const auto K = a_k_m_global_desc.GetLength(I0);
         const auto M = a_k_m_global_desc.GetLength(I1);
@@ -361,13 +364,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};
 
-        auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(p_a_block_double);
-        auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(p_b_block_double);
+        auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());
 
-        auto a_block_odd_buf =
-            make_dynamic_buffer<AddressSpace::Lds>(p_a_block_double + a_block_space_size);
-        auto b_block_odd_buf =
-            make_dynamic_buffer<AddressSpace::Lds>(p_b_block_double + b_block_space_size);
+        auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());
 
         // LDS double buffer: preload data into LDS
         {
......
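
As a usage note on the even/odd LDS buffers above: both logical buffers live in one backing allocation, offset by a_block_space_size/b_block_space_size, and each now carries its own element-space size taken from the corresponding block descriptor. A trivial, hypothetical stand-in:

    #include <cstddef>

    int main()
    {
        constexpr std::size_t block_space_size = 256; // stand-in for a_block_space_size
        static float backing[2 * block_space_size];   // stand-in for p_a_block_double

        // even buffer starts at the base pointer, odd buffer at base + block_space_size;
        // in the kernel each is wrapped by make_dynamic_buffer<AddressSpace::Lds>(ptr, size)
        float* even = backing;
        float* odd  = backing + block_space_size;

        even[0] = 1.0f;
        odd[0]  = 2.0f;
        return (even[0] < odd[0]) ? 0 : 1;
    }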
@@ -84,9 +84,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};
 
-        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(p_a_global);
-        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(p_b_global);
-        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(p_c_global);
+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_a_global, a_e_k_global_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
 
         constexpr auto E = EPerBlock * 3 * 3;
@@ -223,7 +226,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             true>(b_e_n_ho_wo_global_desc,
                   make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
 
-        auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(p_shared_block);
+        auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(p_shared_block,
+                                                                  a_e_k_desc.GetElementSpaceSize());
 
         // register allocation for output
         StaticBuffer<AddressSpace::Vgpr, FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
......
@@ -192,12 +192,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
             return dst_data_idx;
         }();
 
-        // copy data
         typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
 
         using dst_vector_t =
             typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
 
+        // copy data from src_buf into dst_vector
        static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
             constexpr index_t src_offset = src_desc.CalculateOffset(
                 src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
@@ -209,34 +209,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
         const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
             dst_desc, dst_slice_origin_coord_);
 
-        if constexpr(SrcBuffer::GetAddressSpace() == AddressSpace::Vgpr &&
-                     DstBuffer::GetAddressSpace() == AddressSpace::Global)
-        {
-#if CK_USE_AMD_BUFFER_ADDRESSING
-            amd_buffer_store_v2<DstData, DstScalarPerVector>(
-                dst_vector.template AsType<dst_vector_t>()(Number<0>{}),
-                dst_buf.p_data_,
-                dst_slice_origin_coord_.GetOffset(),
-                is_dst_valid,
-                dst_desc.GetElementSpaceSize());
-#else
-            if(is_dst_valid)
-            {
-                *reinterpret_cast<dst_vector_t*>(
-                    &(dst_buf.p_data_[dst_slice_origin_coord_.GetOffset()])) =
-                    dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
-            }
-#endif
-        }
-        else
-        {
-            if(is_dst_valid)
-            {
-                *reinterpret_cast<dst_vector_t*>(
-                    &(dst_buf.p_data_[dst_slice_origin_coord_.GetOffset()])) =
-                    dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
-            }
-        }
+        // copy data from dst_vector into dst_buf
+        dst_buf.template Set<dst_vector_t>(
+            dst_slice_origin_coord_.GetOffset(),
+            is_dst_valid,
+            dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
 
         constexpr auto move_on_dim = [&]() constexpr
         {
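
The hunk above removes the last call site that spelled out the store path by hand; what dst_buf.Set now performs is, in essence, a predicated vector store. A simplified host-side sketch with assumed types (not the CK implementation):

    #include <cstddef>
    #include <cstring>

    struct float4 { float x, y, z, w; }; // stand-in for dst_vector_t

    // write the whole vector only when the destination coordinate is valid,
    // mirroring the is_dst_valid predicate forwarded to dst_buf.Set
    void predicated_vector_store(float* p_dst, std::size_t offset, bool is_dst_valid, const float4& v)
    {
        if(is_dst_valid)
        {
            // memcpy stands in for *reinterpret_cast<dst_vector_t*>(&p_dst[offset]) = v
            std::memcpy(&p_dst[offset], &v, sizeof(float4));
        }
        // out-of-range writes are dropped, matching hardware buffer_store behaviour
    }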
@@ -536,7 +513,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
             return src_data_idx;
         }();
 
-        // copy data
         typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
 
         using src_vector_t =
@@ -545,30 +521,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
         const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
             src_desc, src_slice_origin_coord_);
 
-        if constexpr(SrcBuffer::GetAddressSpace() == AddressSpace::Global)
-        {
-#if CK_USE_AMD_BUFFER_ADDRESSING
-            src_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
-                    src_buf.p_data_,
-                    src_slice_origin_coord_.GetOffset(),
-                    is_src_valid,
-                    src_desc.GetElementSpaceSize());
-#else
-            src_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                is_src_valid ? *reinterpret_cast<const src_vector_t*>(
-                                   &src_buf.p_data_[src_slice_origin_coord_.GetOffset()])
-                             : src_vector_t{0};
-#endif
-        }
-        else
-        {
-            src_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                is_src_valid ? *reinterpret_cast<const src_vector_t*>(
-                                   &src_buf.p_data_[src_slice_origin_coord_.GetOffset()])
-                             : src_vector_t{0};
-        }
+        // copy data from src_buf into src_vector
+        src_vector.template AsType<src_vector_t>()(Number<0>{}) =
+            src_buf.template Get<src_vector_t>(src_slice_origin_coord_.GetOffset(),
+                                               is_src_valid);
 
+        // copy data from src_vector into dst_buf
         static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
             constexpr index_t dst_offset =
                 dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
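
Correspondingly, src_buf.Get replaces the hand-written load path above with a predicated vector load that yields a zero-filled value for invalid offsets, matching what amd_buffer_load_v2 returns for out-of-range reads. A simplified host-side sketch with assumed types:

    #include <cstddef>
    #include <cstring>

    struct float4 { float x, y, z, w; }; // stand-in for src_vector_t

    float4 predicated_vector_load(const float* p_src, std::size_t offset, bool is_src_valid)
    {
        float4 v{0.0f, 0.0f, 0.0f, 0.0f};
        if(is_src_valid)
        {
            // memcpy stands in for *reinterpret_cast<const src_vector_t*>(&p_src[offset])
            std::memcpy(&v, &p_src[offset], sizeof(float4));
        }
        return v;
    }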
@@ -878,7 +836,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
             return src_data_idx;
         }();
 
-        // copy data from src_buf to src_tmp_vector
         vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
 
         using src_vector_t = typename decltype(src_tmp_vector)::type;
@@ -886,29 +843,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
         const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
             src_desc, src_slice_origin_coord_);
 
-        if constexpr(SrcBuffer::GetAddressSpace() == AddressSpace::Global)
-        {
-#if CK_USE_AMD_BUFFER_ADDRESSING
-            src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
-                    src_buf.p_data_,
-                    src_slice_origin_coord_.GetOffset(),
-                    is_src_valid,
-                    src_desc.GetElementSpaceSize());
-#else
-            src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                is_src_valid ? *reinterpret_cast<const src_vector_t*>(
-                                   &src_buf.p_data_[src_slice_origin_coord_.GetOffset()])
-                             : src_vector_t{0};
-#endif
-        }
-        else
-        {
-            src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                is_src_valid ? *reinterpret_cast<const src_vector_t*>(
-                                   &src_buf.p_data_[src_slice_origin_coord_.GetOffset()])
-                             : src_vector_t{0};
-        }
+        // copy data from src_buf to src_tmp_vector
+        src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
+            src_buf.template Get<src_vector_t>(src_slice_origin_coord_.GetOffset(),
+                                               is_src_valid);
 
         // copy data from src_tmp_vector to buffer_
         static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
@@ -1068,13 +1006,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
             return dst_data_idx;
         }();
 
-        // copy data
-        // hardcoding for ds_write
-        // TODO refactor transfer_data() to encapsulate this
-        static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Lds &&
-                          DstInMemOp == InMemoryDataOperation::Set,
-                      "wrong! hardcoded for ds_write");
-
         vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;
 
         // copy data from buffer_ to dst_tmp_vector
@@ -1088,8 +1019,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
         using dst_vector_t = typename decltype(dst_tmp_vector)::type;
 
         // copy data from dst_tmp_vector to dst_buf
+        const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
+            dst_desc, dst_slice_origin_coord_);
+
         dst_buf.template Set<dst_vector_t>(
             dst_slice_origin_coord_.GetOffset(),
+            is_dst_valid,
             dst_tmp_vector.template AsType<dst_vector_t>()[Number<0>{}]);
 
         constexpr auto move_on_dim = [&]() constexpr
@@ -1499,7 +1434,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
             move_dynamic_tensor_coordinate(
                 src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
 
-            // copy data from src_buf into src_tmp_buffer
             vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
 
             using src_vector_t = typename decltype(src_tmp_vector)::type;
@@ -1507,9 +1441,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
             const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                 src_desc, src_data_coord);
 
+            // copy data from src_buf into src_tmp_vector
             src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                is_src_valid ? src_buf.template Get<src_vector_t>(src_data_coord.GetOffset())
-                             : src_vector_t{0};
+                src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
 
             // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
             // DstData)
......
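
Taken together, the threadwise transfer loops above reduce to "read with Get, write with Set", with address-space dispatch and predication hidden inside the buffer objects. A hedged, self-contained illustration (SimpleBuf and the validity check are simplifications, not CK code):

    #include <cstddef>
    #include <cstdio>

    // minimal stand-in for DynamicBuffer: pointer + size, predicated element access
    struct SimpleBuf
    {
        float* p;
        std::size_t n;
        float Get(std::size_t i, bool valid) const { return valid ? p[i] : 0.0f; }
        void Set(std::size_t i, bool valid, float x) { if(valid) p[i] = x; }
    };

    int main()
    {
        float src[4] = {1, 2, 3, 4};
        float dst[4] = {};
        SimpleBuf src_buf{src, 4};
        SimpleBuf dst_buf{dst, 4};

        // the copy no longer branches on whether src/dst are Global, Lds or Vgpr;
        // it only forwards the coordinate-validity flag
        for(std::size_t i = 0; i < 4; ++i)
        {
            const bool is_valid = i < dst_buf.n; // stand-in for coordinate_has_valid_offset_...
            dst_buf.Set(i, is_valid, src_buf.Get(i, is_valid));
        }

        std::printf("%f\n", dst[3]);
        return 0;
    }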
@@ -323,7 +323,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
     }
     else if constexpr(N == 2)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         return __llvm_amdgcn_raw_buffer_load_i8x2(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
 #else
@@ -335,7 +335,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
     }
     else if constexpr(N == 4)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         return __llvm_amdgcn_raw_buffer_load_i8x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
 #else
@@ -347,7 +347,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
     }
     else if constexpr(N == 8)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         vector_type<int8_t, 8> tmp;
 
         tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
} }
else if constexpr(N == 16) else if constexpr(N == 16)
{ {
#if !CK_WORKAROUND_SWDEV_XXXXXX #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type<int8_t, 16> tmp; vector_type<int8_t, 16> tmp;
tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4( tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
@@ -483,7 +483,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
     }
     else if constexpr(N == 2)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         __llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
                                             dst_wave_buffer_resource,
                                             dst_thread_addr_offset,
@@ -499,7 +499,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
     }
     else if constexpr(N == 4)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         __llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
                                             dst_wave_buffer_resource,
                                             dst_thread_addr_offset,
......
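
Beyond the macro rename, the N == 8 and N == 16 branches above show the pattern of assembling a wide int8 value from 4-byte buffer-load intrinsics. A plain C++ stand-in for that composition (the function name and memcpy-based loads are illustrative, not the GCN intrinsics):

    #include <cstdint>
    #include <cstring>

    // build an 8-byte int8 vector from two 4-byte pieces, analogous to filling
    // vector_type<int8_t, 8> from two __llvm_amdgcn_raw_buffer_load_i8x4 results
    std::int64_t load_i8x8_as_two_i8x4(const std::int8_t* p)
    {
        std::int32_t part[2];
        std::memcpy(&part[0], p, 4);     // first 4-byte load
        std::memcpy(&part[1], p + 4, 4); // second 4-byte load at a +4 byte offset
        std::int64_t out;
        std::memcpy(&out, part, 8);
        return out;
    }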
@@ -8,8 +8,6 @@
 #include "container_element_picker.hpp"
 #include "data_type.hpp"
 #include "float_type.hpp"
-#include "static_buffer.hpp"
-#include "dynamic_buffer.hpp"
 #include "functional.hpp"
 #include "functional2.hpp"
 #include "functional3.hpp"
@@ -26,6 +24,8 @@
 #include "type.hpp"
 #include "utility.hpp"
 #include "magic_division.hpp"
+#include "static_buffer.hpp"
+#include "dynamic_buffer.hpp"
 
 #if CK_USE_AMD_INLINE_ASM
 #include "amd_inline_asm.hpp"
......
@@ -143,8 +143,8 @@
 #endif
 
 // workaround for compiler crash when using buffer load/store for i8
-#ifndef CK_WORKAROUND_SWDEV_XXXXXX
-#define CK_WORKAROUND_SWDEV_XXXXXX 1
+#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
+#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1
 #endif
 
 // workaround for compiler crash when using buffer load/store for i8
......
@@ -3,14 +3,20 @@
 namespace ck {
 
-template <AddressSpace BufferAddressSpace, typename T>
+#include "amd_buffer_addressing_v2.hpp"
+
+template <AddressSpace BufferAddressSpace, typename T, typename ElementSpaceSize>
 struct DynamicBuffer
 {
     using type = T;
 
     T* p_data_;
+    ElementSpaceSize element_space_size_;
 
-    __host__ __device__ constexpr DynamicBuffer(T* p_data) : p_data_{p_data} {}
+    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
+        : p_data_{p_data}, element_space_size_{element_space_size}
+    {
+    }
 
     __host__ __device__ static constexpr AddressSpace GetAddressSpace()
     {
@@ -26,13 +32,33 @@ struct DynamicBuffer
                   is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                           typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                   bool>::type = false>
-    __host__ __device__ constexpr const auto Get(index_t i) const
-    {
-#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
-        return *reinterpret_cast<const X*>(&p_data_[i]);
-#else
-        return *reinterpret_cast<const X*>(&p_data_[i]);
-#endif
-    }
+    __host__ __device__ constexpr const auto Get(index_t i, bool is_valid_offset) const
+    {
+        // X contains multiple T
+        constexpr index_t scalar_per_t_vector =
+            scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
+
+        constexpr index_t scalar_per_x_vector =
+            scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
+
+        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
+                      "wrong! X need to be multiple T");
+
+        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
+
+        if constexpr(GetAddressSpace() == AddressSpace::Global)
+        {
+#if CK_USE_AMD_BUFFER_ADDRESSING
+            return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
+                p_data_, i, is_valid_offset, element_space_size_);
+#else
+            return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
+#endif
+        }
+        else
+        {
+            return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
+        }
+    }
 
     template <typename X,
@@ -40,34 +66,74 @@ struct DynamicBuffer
                   is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                           typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                   bool>::type = false>
-    __host__ __device__ void Set(index_t i, const X& x)
-    {
-#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
-        *reinterpret_cast<X*>(&p_data_[i]) = x;
-#else
-        if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
-                             int8_t>::value)
-        {
-            static_assert(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
-                              is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value,
-                          "wrong! not implemented for this combination, please add implementation");
-
-            if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
-                         is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
-            {
-#if 0
-                *reinterpret_cast<int32x4_t*>(&p_data_[i]) = as_type<int32x4_t>(x);
-#else
-                *reinterpret_cast<int32x4_t*>(&p_data_[i]) =
-                    *reinterpret_cast<const int32x4_t*>(&x);
-#endif
-            }
-        }
-        else
-        {
-            *reinterpret_cast<X*>(&p_data_[i]) = x;
-        }
-#endif
-    }
+    __host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x)
+    {
+        // X contains multiple T
+        constexpr index_t scalar_per_t_vector =
+            scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
+
+        constexpr index_t scalar_per_x_vector =
+            scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
+
+        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
+                      "wrong! X need to be multiple T");
+
+        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
+
+        if constexpr(GetAddressSpace() == AddressSpace::Global)
+        {
+#if CK_USE_AMD_BUFFER_ADDRESSING
+            amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
+                x, p_data_, i, is_valid_offset, element_space_size_);
+#else
+            if(is_valid_offset)
+            {
+                *reinterpret_cast<X*>(&p_data_[i]) = x;
+            }
+#endif
+        }
+        else if constexpr(GetAddressSpace() == AddressSpace::Lds)
+        {
+            if(is_valid_offset)
+            {
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
+                *reinterpret_cast<X*>(&p_data_[i]) = x;
+#else
+                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
+                // ISA, so I try to let compiler emit use IR "store<i32, 4>" which would be lower to
+                // ds_write_b128
+                // TODO: remove this after compiler fix
+                if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
+                                     int8_t>::value)
+                {
+                    static_assert(
+                        is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
+                            is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value,
+                        "wrong! not implemented for this combination, please add implementation");
+
+                    if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
+                                 is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
+                    {
+                        // HACK: compiler would emit IR "store<i32, 4>" if using this
+                        // TODO: remove this after compiler fix
+                        *reinterpret_cast<int32x4_t*>(&p_data_[i]) =
+                            *reinterpret_cast<const int32x4_t*>(&x);
+                    }
+                }
+                else
+                {
+                    *reinterpret_cast<X*>(&p_data_[i]) = x;
+                }
+#endif
+            }
+        }
+        else
+        {
+            if(is_valid_offset)
+            {
+                *reinterpret_cast<X*>(&p_data_[i]) = x;
+            }
+        }
+    }
 
     __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
@@ -75,10 +141,12 @@ struct DynamicBuffer
     __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
 };
 
-template <AddressSpace BufferAddressSpace = AddressSpace::Generic, typename T>
-__host__ __device__ constexpr auto make_dynamic_buffer(T* p)
+template <AddressSpace BufferAddressSpace = AddressSpace::Generic,
+          typename T,
+          typename ElementSpaceSize>
+__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
 {
-    return DynamicBuffer<BufferAddressSpace, T>{p};
+    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size};
 }
 
 } // namespace ck
......
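
The Get/Set bodies above compute t_per_x to pick the load/store width: when the buffer element is T and the access type X is a wider vector of the same scalar, t_per_x is the number of T per X. A hedged constexpr sketch with hypothetical example types:

    #include <cstdio>

    template <typename Scalar, int N>
    struct vec { Scalar data[N]; }; // stand-in for CK's vector_type

    template <typename V> struct scalar_info;
    template <> struct scalar_info<float>         { static constexpr int vector_size = 1; };
    template <> struct scalar_info<vec<float, 4>> { static constexpr int vector_size = 4; };

    int main()
    {
        constexpr int scalar_per_t_vector = scalar_info<float>::vector_size;         // T = float
        constexpr int scalar_per_x_vector = scalar_info<vec<float, 4>>::vector_size; // X = float4

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "X must be a whole multiple of T");

        constexpr int t_per_x = scalar_per_x_vector / scalar_per_t_vector; // 4: use a 4-wide load/store
        std::printf("t_per_x = %d\n", t_per_x);
        return 0;
    }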
@@ -63,12 +63,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
 #if 0
     // run-time variables
-    const auto in_n_c_hi_wi_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
-    const auto wei_k_c_y_x_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
-    const auto out_n_k_ho_wo_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
+    const auto in_n_c0_hi_wi_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, C0, Hi, Wi));
+    const auto wei_k_c0_y_x_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, C0, Y, X));
+    const auto out_n_k0_ho_wo_k1_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, K0, Ho, Wo, K1));
 
     const auto conv_strides = to_multi_index(ConvStrides{});
     const auto conv_dilations = to_multi_index(ConvDilations{});
......
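
The run-time descriptors above are "packed" naive descriptors: strides are the suffix products of the lengths, so the element space size is simply the product of all lengths. A hypothetical helper (not the CK API) showing that relationship for the new (N, C0, Hi, Wi) layout:

    #include <array>
    #include <cstddef>
    #include <cstdio>

    // packed layout: stride of the last dimension is 1, each earlier stride is the
    // product of all lengths to its right
    template <std::size_t R>
    std::array<long, R> packed_strides(const std::array<long, R>& lengths)
    {
        std::array<long, R> strides{};
        long s = 1;
        for(std::size_t d = R; d-- > 0;)
        {
            strides[d] = s;
            s *= lengths[d];
        }
        return strides;
    }

    int main()
    {
        const std::array<long, 4> lengths{1, 4, 28, 28}; // example N, C0, Hi, Wi values
        const auto strides = packed_strides(lengths);
        std::printf("strides: %ld %ld %ld %ld\n", strides[0], strides[1], strides[2], strides[3]);
        return 0;
    }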
@@ -92,7 +92,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<1, 1>;
     using RightPads = Sequence<1, 1>;
-#elif 1
+#elif 0
     constexpr index_t N = 1;
     constexpr index_t C = 16;
     constexpr index_t HI = 540;
@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
     print_array("ConvStrides", to_multi_index(ConvStrides{}));
     print_array("ConvDilations", to_multi_index(ConvDilations{}));
 
-#if 0
+#if 1
     using in_data_t = float;
     constexpr index_t in_vector_size = 1;
     using acc_data_t = float;
@@ -740,7 +740,7 @@ int main(int argc, char* argv[])
                                                                       LeftPads{},
                                                                       RightPads{},
                                                                       nrepeat);
-#elif 0
+#elif 1
     device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
                                                                          in_vector_size,
                                                                          acc_data_t,
......