Commit edc08fe6 authored by Chao Liu's avatar Chao Liu
Browse files

static kernel use raw buffer load/store

parent ecad4061
......@@ -173,14 +173,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
}();
// copy data
// hardcoding for buffer_store
// TODO refactor transfer_data() to encapsulate this
static_assert(SrcAddressSpace == AddressSpace::Vgpr &&
DstAddressSpace == AddressSpace::Global,
"wrong! hardcoded to use buffer_store");
vector_type<DstData, DstScalarPerVector> dst_vector;
using dst_vector_t = typename vector_type<DstData, DstScalarPerVector>::MemoryType;
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset =
src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
......@@ -189,13 +185,35 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
dst_vector.Scalars()(i) = p_src[Number<src_offset>{}];
});
amd_buffer_store_v2<DstData, DstScalarPerVector>(
dst_vector.Vector(),
p_dst,
dst_slice_origin_coord_.GetOffset(),
coordinate_has_valid_offset_assuming_visible_index_is_valid(
dst_desc, dst_slice_origin_coord_),
dst_desc.GetElementSpaceSize());
const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
dst_desc, dst_slice_origin_coord_);
if constexpr(SrcAddressSpace == AddressSpace::Vgpr &&
DstAddressSpace == AddressSpace::Global)
{
#if CK_USE_AMD_BUFFER_ADDRESSING
amd_buffer_store_v2<DstData, DstScalarPerVector>(
dst_vector.Vector(),
p_dst,
dst_slice_origin_coord_.GetOffset(),
is_dst_valid,
dst_desc.GetElementSpaceSize());
#else
if(is_dst_valid)
{
*reinterpret_cast<dst_vector_t*>(&(p_dst[dst_slice_origin_coord_.GetOffset])) =
dst_vector.Vector();
}
#endif
}
else
{
if(is_dst_valid)
{
*reinterpret_cast<dst_vector_t*>(&(p_dst[dst_slice_origin_coord_.GetOffset])) =
dst_vector.Vector();
}
}
constexpr auto move_on_dim = [&]() constexpr
{
......@@ -482,33 +500,36 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
}();
// copy data
// hardcoding for buffer_store
// TODO refactor transfer_data() to encapsulate this
static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for ds_read");
vector_type<SrcData, SrcScalarPerVector> src_vector;
using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::MemoryType;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_);
if constexpr(SrcAddressSpace == AddressSpace::Global)
{
const bool is_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_);
#if CK_USE_AMD_BUFFER_ADDRESSING
src_vector.Vector() = amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
p_src,
src_slice_origin_coord_.GetOffset(),
is_valid,
is_src_valid,
src_desc.GetElementSpaceSize());
#else
src_vector.Vector() = is_src_valid
? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
#endif
}
else
{
const bool is_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_);
src_vector.Vector() = is_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
src_vector.Vector() = is_src_valid
? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
}
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
......@@ -815,23 +836,35 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
}();
// copy data
// hardcoding for buffer_load
// TODO refactor transfer_data() to encapsulate this
static_assert(SrcAddressSpace == AddressSpace::Global,
"wrong! hardcoded to use buffer_load, src must be global mem");
vector_type<SrcData, SrcScalarPerVector> src_vector;
using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::MemoryType;
const bool is_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_);
src_vector.Vector() =
amd_buffer_load_v2<SrcData, SrcScalarPerVector>(p_src,
src_slice_origin_coord_.GetOffset(),
is_valid,
src_desc.GetElementSpaceSize());
if constexpr(SrcAddressSpace == AddressSpace::Global)
{
#if CK_USE_AMD_BUFFER_ADDRESSING
src_vector.Vector() = amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
p_src,
src_slice_origin_coord_.GetOffset(),
is_src_valid,
src_desc.GetElementSpaceSize());
#else
src_vector.Vector() = is_src_valid
? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
#endif
}
else
{
src_vector.Vector() = is_src_valid
? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
}
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t buffer_offset =
......
......@@ -89,7 +89,7 @@ struct SetData
if(dst_valid)
{
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
amd_buffer_load<T, DataPerAccess>(p_src, src_offset, src_valid, src_range);
amd_buffer_load_v2<T, DataPerAccess>(p_src, src_offset, src_valid, src_range);
}
}
......@@ -109,12 +109,12 @@ struct SetData
{
const auto zeros = vector_t(0);
amd_buffer_store<T, DataPerAccess>(src_valid ? &(p_src[src_offset])
: reinterpret_cast<const T*>(&zeros),
p_dst,
dst_offset,
dst_valid,
dst_range);
amd_buffer_store_v2<T, DataPerAccess>(
src_valid ? *reinterpret_cast<const vector_t*>(&(p_src[src_offset])) : zeros,
p_dst,
dst_offset,
dst_valid,
dst_range);
}
#endif
};
......
......@@ -67,7 +67,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif
#if 1
#if 0
// cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64;
......
......@@ -674,7 +674,7 @@ int main(int argc, char* argv[])
LeftPads{},
RightPads{},
nrepeat);
#elif 1
#elif 0
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
......@@ -686,7 +686,7 @@ int main(int argc, char* argv[])
LeftPads{},
RightPads{},
nrepeat);
#elif 0
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment