Unverified commit 4396a224 authored by Harisankar Sadasivan, committed by GitHub

Merge branch 'develop' into mi300_time_measurement

parents 0a27f07e 501a6b68
......@@ -439,7 +439,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template <typename BLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
MakeBGridDescriptor_N_K(const index_t NRaw, const index_t KRaw, const index_t StrideB)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
......@@ -463,15 +463,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template <typename BsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeBsGridDescriptor_N_K(const std::array<index_t, NumBTensor>& KRaws,
const std::array<index_t, NumBTensor>& NRaws,
MakeBsGridDescriptor_N_K(const std::array<index_t, NumBTensor>& NRaws,
const std::array<index_t, NumBTensor>& KRaws,
const std::array<index_t, NumBTensor>& BsStride)
{
return generate_tuple(
[&](auto i) {
using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
return MakeBGridDescriptor_N_K<BLayout, GemmSpec>(KRaws[i], NRaws[i], BsStride[i]);
return MakeBGridDescriptor_N_K<BLayout, GemmSpec>(NRaws[i], KRaws[i], BsStride[i]);
},
Number<NumBTensor>{});
}
......@@ -574,7 +574,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
{
return;
}
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
......@@ -595,8 +594,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{});
#if 0
static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
"Src and Dst ScalarPerVector must be the same");
#endif
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
......@@ -626,8 +627,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{});
#if 0
static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
"Src and Dst ScalarPerVector must be the same");
#endif
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
......
......@@ -10,38 +10,9 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
// Do the following things to avoid "alloca" in LLVM-IR, which would cause scratch memory usage
// and sometimes useless instructions:
// 1. Don't save a reference to a tensor descriptor in the class; pass the tensor descriptor in as an argument
// instead
// 2. Don't construct a new tensor coordinate every time it is used; update and reuse the same
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
namespace detail {
// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
// doesn't have a constructor
template <index_t VectorDim, index_t ScalarPerVector>
struct lambda_scalar_per_access
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
return (i == VectorDim) ? ScalarPerVector : 1;
}
};
template <index_t VectorDim>
struct lambda_scalar_step_in_vector
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
return (i == VectorDim) ? 1 : 0;
}
};
} // namespace detail
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
// Assume:
// 1. src:
// 1. SrcDesc is known at compile-time
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
// Do the following things to avoid "alloca" in LLVM-IR, which would cause scratch memory usage
// and sometimes useless instructions:
// 1. Don't save a reference to a tensor descriptor in the class; pass the tensor descriptor in as an argument
// instead
// 2. Don't construct a new tensor coordinate every time it is used; update and reuse the same
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
namespace detail {
// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
// doesn't have a constructor
template <index_t VectorDim, index_t ScalarPerVector>
struct lambda_scalar_per_access
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
return (i == VectorDim) ? ScalarPerVector : 1;
}
};
template <index_t VectorDim>
struct lambda_scalar_step_in_vector
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
return (i == VectorDim) ? 1 : 0;
}
};
// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
// doesn't have a constructor
template <index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t DstVectorDim,
index_t DstScalarPerVector>
struct lambda_scalar_per_access_for_src_and_dst
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
if(i == SrcVectorDim && i == DstVectorDim)
{
return math::lcm(SrcScalarPerVector, DstScalarPerVector);
}
else if(i == SrcVectorDim)
{
return SrcScalarPerVector;
}
else if(i == DstVectorDim)
{
return DstScalarPerVector;
}
else
{
return 1;
}
}
};
} // namespace detail
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -7,43 +7,12 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/utility/is_detected.hpp"
namespace ck {
namespace detail {
// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
// doesn't have a constructor
template <index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t DstVectorDim,
index_t DstScalarPerVector>
struct lambda_scalar_per_access_for_src_and_dst
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
if(i == SrcVectorDim && i == DstVectorDim)
{
return math::lcm(SrcScalarPerVector, DstScalarPerVector);
}
else if(i == SrcVectorDim)
{
return SrcScalarPerVector;
}
else if(i == DstVectorDim)
{
return DstScalarPerVector;
}
else
{
return 1;
}
}
};
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
} // namespace detail
namespace ck {
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -8,9 +8,11 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/utility/is_detected.hpp"
#include "ck/tensor/static_tensor.hpp"
namespace ck {
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
// Thread-level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
......@@ -70,16 +72,18 @@ struct ThreadwiseTensorSliceTransfer_v7r2
static constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
using SrcSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
SrcDimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>>;
static constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
using SrcSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
SrcDimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>,
false>;
using DstSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DstDimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
remove_cv_t<decltype(dst_scalar_per_access)>,
false>;
__device__ constexpr ThreadwiseTensorSliceTransfer_v7r2(
const SrcDescs& src_descs,
......@@ -139,9 +143,9 @@ struct ThreadwiseTensorSliceTransfer_v7r2
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
{
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
auto dst_vectors = generate_vectors<DstDatas, DstScalarPerVector>();
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
// copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) {
......@@ -199,7 +203,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
},
Number<nDst>{});
......@@ -214,10 +218,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2
unpack2(element_op_, dst_data_refs, src_data_refs);
});
dst_vectors_tuple_(iAccess) = dst_vectors;
elm_vectors_tuple_(iAccess) = elm_vectors;
// move coordinate
if constexpr(iAccess.value != num_access - 1)
if constexpr(iAccess.value != src_num_access - 1)
{
constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess);
......@@ -241,15 +245,113 @@ struct ThreadwiseTensorSliceTransfer_v7r2
});
}
__device__ void TransposeFromElmToDst()
{
using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;
using SrcThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
SrcScalarPerVector,
decltype(GetSrcThreadScratchDescriptor()),
true>;
using DstThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
DstScalarPerVector,
decltype(GetDstThreadScratchDescriptor()),
true>;
SrcThreadScratch elm_thread_scratch_;
DstThreadScratch dst_thread_scratch_;
elm_thread_scratch_.data_ =
bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_);
if constexpr(SrcVectorDim != DstVectorDim &&
((is_same<half_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{
// each transpose consumes
// DstScalarPerVector src vectors from src_thread_scratch_ and produces
// SrcScalarPerVector dst vectors in dst_thread_scratch_
constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
// Assume SrcVectorDim is not the same as DstVectorDim, so we do a transpose
// TODO: make this logic generic for all scenarios
constexpr auto src_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector>{},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
constexpr auto data_idx = access_idx * scalar_per_access;
constexpr auto data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
// get DstScalarPerVector # of read-only references to src vectors from
// src_thread_scratch_
const auto src_vector_refs = generate_tie(
[&](auto i) -> const src_vector_t& {
// i increment corresponds to movement in DstVectorDim
return elm_thread_scratch_.GetVectorTypeReference(
data_idx_seq + i * dst_scalar_step_in_vector);
},
Number<num_src_vector>{});
// get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
auto dst_vector_refs = generate_tie(
[&](auto i) -> dst_vector_t& {
// i increment corresponds to movement in SrcVectorDim
return dst_thread_scratch_.GetVectorTypeReference(
data_idx_seq + i * src_scalar_step_in_vector);
},
Number<num_dst_vector>{});
// do data transpose
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
});
}
else
{
static_ford<SliceLengths>{}(
[&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; });
}
dst_vectors_tuple_ = bit_cast<decltype(dst_vectors_tuple_)>(dst_thread_scratch_.data_);
}
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template <typename DstBuffers,
enable_if_t<DstDescs::Size() == DstBuffers::Size(), bool> = false>
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
{
TransposeFromElmToDst();
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[iAccess];
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[Number<iAccess>{}];
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
......@@ -269,7 +371,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
});
// move coordinate
if constexpr(iAccess.value != num_access - 1)
if constexpr(iAccess.value != dst_num_access - 1)
{
constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);
......@@ -312,28 +414,126 @@ struct ThreadwiseTensorSliceTransfer_v7r2
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
if constexpr(num_access == 0)
if constexpr(src_num_access == 0)
{
return typename SrcSpaceFillingCurve::Index{};
}
else
{
return SrcSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return SrcSpaceFillingCurve::GetStepBetween(Number<src_num_access - 1>{}, Number<0>{});
}
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
if constexpr(num_access == 0)
if constexpr(dst_num_access == 0)
{
return typename DstSpaceFillingCurve::Index{};
}
else
{
return DstSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return DstSpaceFillingCurve::GetStepBetween(Number<dst_num_access - 1>{}, Number<0>{});
}
}
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
// constexpr auto src_scalar_per_access = generate_sequence(
// detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(src_access_lengths_and_vector_length[i],
src_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
// 1st stage of transforms
// constexpr auto dst_scalar_per_access = generate_sequence(
// detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(dst_access_lengths_and_vector_length[i],
dst_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
template <index_t ISrc>
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
......@@ -372,11 +572,14 @@ struct ThreadwiseTensorSliceTransfer_v7r2
private:
using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>());
using ElmVectorsType = decltype(generate_vectors<DstDatas, SrcScalarPerVector>());
using DstVectorsType = decltype(generate_vectors<DstDatas, DstScalarPerVector>());
static constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess();
static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
StaticallyIndexedArray<DstVectorsType, num_access> dst_vectors_tuple_;
StaticallyIndexedArray<ElmVectorsType, src_num_access> elm_vectors_tuple_;
StaticallyIndexedArray<DstVectorsType, dst_num_access> dst_vectors_tuple_;
SrcCoords src_coords_;
DstCoords dst_coords_;
......
# ck_tile
## concept
`ck_tile` provides a programming model with templated abstractions that enables users to implement performance-critical kernels for machine learning workloads. It introduces the following basic concepts to help users build their own operators:
- tensor coordinate transformation, the core abstraction for layout/index transforms at both compile time and run time.
- a tile-based programming model, including the tile-level API and the concept of a distributed tensor.
`ck_tile` is independent of the old CK and is located under [/include/ck_tile](/include/ck_tile). You don't need to include anything from the old CK; `ck_tile` has similar (indeed almost the same) implementations for users to build operators. There will be a transition period while everything from the old CK is pulled into `ck_tile`; stay tuned.
## component
`ck_tile` is split into several components, including `core`, `host`, `ops/gemm`, `ops/fmha`... For each component you only need to include a single header (e.g. `#include "ck_tile/core.hpp"`, `#include "ck_tile/ops/fmha.hpp"`) to use the functions/structures inside (different from the old `ck`).
**[core]**
`ck_tile/core` contains all the basic data structures and functions needed to build a kernel; you can include just this header and build your own operators using all the basic building blocks introduced in CK.
`core/container`
- array, stores run-time variables with a fixed length (tensor index, register buffer, etc.)
- tuple, similar to std::tuple; holds data of different types, and is one way to achieve multiple buffers
- sequence, a compile-time integer sequence used to build various internal structures or to describe tile sizes
- other convenient structures built on top of the above three (a short sketch follows)
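
A minimal sketch (not from this commit) of how these containers appear in kernel code, using only the public `ck_tile/core.hpp` header; the function name is made up for illustration:

```
#include "ck_tile/core.hpp"

using namespace ck_tile;

CK_TILE_HOST_DEVICE void container_sketch()
{
    array<index_t, 2> idx{0};                 // fixed-length run-time values, e.g. a tensor index
    idx[1] = 7;

    auto pair = make_tuple(idx, number<4>{}); // heterogeneous data held in one object
    (void)pair;

    using tile_shape = sequence<64, 64>;      // compile-time integers, e.g. a tile size
    static_assert(tile_shape::size() == 2, "two tile dimensions");
}
```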
`core/numeric`
- GPU data types such as `fp16_t`, `bf16_t`, `fp8_t`... and conversions between them
- a constexpr integer, similar to std::integral_constant, to be used as a compile-time integer
- math functions and numeric utilities (a short sketch follows)
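
A small sketch of the numeric pieces, under the assumption that `type_convert<>` in `ck_tile` mirrors the helper of the same name in the old CK:

```
#include "ck_tile/core.hpp"

using namespace ck_tile;

CK_TILE_HOST_DEVICE float numeric_sketch(float x)
{
    fp16_t h = type_convert<fp16_t>(x);   // fp32 -> fp16, one of the GPU data types above
    constexpr auto four = number<4>{};    // integral_constant-like compile-time integer
    return type_convert<float>(h) * four; // back to fp32, scaled by a compile-time constant
}
```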
`core/algorithm`
- the coordinate transformation system, used to build tensor transforms and compile-time indexing. This is the core idea introduced in the old `ck` to describe how a tensor is built from several basic transform primitives like `merge`/`unmerge`/`embed` etc., and how indexing into an ND tensor is finally mapped to a 1D memory offset (an illustrative sketch follows).
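
As an illustration, the sketch below folds a 2x8 index space into a single dimension with `merge`, following the same pattern that `make_cluster_descriptor` (shown later in this commit) uses; the helper name is invented for the example:

```
#include "ck_tile/core.hpp"

using namespace ck_tile;

// fold a (2, 8) index space into one dimension of length 16;
// upper index u maps back to the lower index (u / 8, u % 8)
CK_TILE_HOST_DEVICE constexpr auto make_2d_to_1d_adaptor()
{
    const auto merge = make_merge_transform(make_tuple(number<2>{}, number<8>{}));

    return make_single_stage_tensor_adaptor(make_tuple(merge),
                                            make_tuple(sequence<0, 1>{}), // lower dims
                                            make_tuple(sequence<0>{}));   // upper dim
}
```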
`core/tensor`
- tensor descriptor, which describes the layout of an ND tensor
- distributed tensor, which describes both the storage of a tensor and how a collection of threads collaboratively works on it
- tile-level API, including `load_tile`, `store_tile`, `shuffle_tile`, `slice_tile`, etc. (a descriptor sketch follows)
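
A hedged sketch of the descriptor piece, assuming `make_naive_tensor_descriptor_packed` is provided by `ck_tile` (mirroring the old CK helper of the same name); tile windows and the tile-level API are then built on top of views over such descriptors:

```
#include "ck_tile/core.hpp"

using namespace ck_tile;

// describe a packed, row-major 4x8 tensor: index (i, j) maps to linear offset i * 8 + j
CK_TILE_HOST_DEVICE constexpr auto make_4x8_packed_descriptor()
{
    return make_naive_tensor_descriptor_packed(make_tuple(number<4>{}, number<8>{}));
}
```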
**[host]**
`ck_tile/host` contains the host-side utilities to launch a kernel and create device buffers, plus some reference implementations. It can be used to create examples (like those under the ck_tile example folder) and simple executables that invoke a kernel, so if you only need `ck_tile` to build your own device library it's fine not to include it. For the same reason, it is recommended to include only the specific headers you need under this folder to avoid pulling in unwanted ones (e.g., only `ck_tile/host/kernel_launch.hpp`), unless you are writing a host executable.
**[ops/gemm, ops/fmha, ops/reduce...]**
Our implementations of different device operators:
- warp, warp-tile-level operators
- block, block-tile-level operators
- pipeline, pipelines that implement a customized tile-level mainloop (or epilogue). By switching the pipeline plugged into the kernel template you can get different kinds of pipeline optimizations.
- kernel, the template interface users instantiate to get a particular kernel
**[ops/epilogue]**
The epilogue part of our kernels. We may extend this epilogue component to let users build their own customized epilogues.
## examples
Currently all ck_tile-related examples live under the [/example/ck_tile](/example/ck_tile/) folder. Please check each example's subfolder.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/map.hpp"
#include "ck_tile/core/container/meta_data_buffer.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/span.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"
#include "ck_tile/core/tensor/load_tile.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/shuffle_tile.hpp"
#include "ck_tile/core/tensor/slice_tile.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/store_tile.hpp"
#include "ck_tile/core/tensor/sweep_tile.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
#include "ck_tile/core/tensor/tensor_coordinate.hpp"
#include "ck_tile/core/tensor/tensor_descriptor.hpp"
#include "ck_tile/core/tensor/tensor_view.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/unary_element_function.hpp"
# ck_tile/core #
`ck_tile/core` contains all the basic functions and structures needed to create a GPU kernel using `ck_tile`. Users should include only this single header, `ck_tile/core.hpp`, to use all the functionality. Everything is under the `ck_tile` namespace. The coding style under this folder is similar to `std` (`snake_case` for structures/functions, CamelCase for template types...).
```
algorithm/
    coordinate transforms and some other reusable algorithms
arch/
    basic device building blocks such as mma and buffer addressing
container/
    basic container data structures: array/sequence/tuple/...
numeric/
    data types and data-type-related math
tensor/
    tensor descriptors and the tile-level API
utility/
    other utility functions for both host and device
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename Lengths,
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
CK_TILE_HOST_DEVICE constexpr auto make_cluster_descriptor(
const Lengths& lengths,
ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type{})
{
constexpr index_t ndim_low = Lengths::size();
const auto reordered_lengths = container_reorder_given_new2old(lengths, order);
const auto low_lengths = generate_tuple(
[&](auto idim_low) { return reordered_lengths[idim_low]; }, number<ndim_low>{});
const auto transform = make_merge_transform(low_lengths);
constexpr auto low_dim_old_top_ids = ArrangeOrder{};
constexpr auto up_dim_new_top_ids = sequence<0>{};
return make_single_stage_tensor_adaptor(
make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
}
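// Illustrative usage (editor's sketch, not part of this file): build a cluster descriptor for a
// hypothetical 4x8 thread cluster; the arrange order controls which dimension varies fastest in
// the linearized id.
CK_TILE_HOST_DEVICE constexpr auto example_make_4x8_cluster_descriptor()
{
    return make_cluster_descriptor(make_tuple(number<4>{}, number<8>{}), sequence<1, 0>{});
}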
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
namespace ck_tile {
enum struct coord_transform_enum
{
undefined,
pass_through,
pad,
embed,
merge,
unmerge,
replicate,
xor_t,
offset,
};
template <index_t NDimLow, index_t NDimUp>
struct base_transform
{
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::undefined;
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_lower_dimension() { return NDimLow; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_upper_dimension() { return NDimUp; }
// return a safe value for vector length/stride, based only on variables known at
// compile time
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths&,
const LowVectorStrides&)
{
if constexpr(NDimUp > 0)
{
array<index_t, NDimUp> up_vector_lengths{-1};
array<index_t, NDimUp> up_vector_strides{-1};
return make_tuple(up_vector_lengths, up_vector_strides);
}
else
{
return make_tuple(array<index_t, 0>{}, array<index_t, 0>{});
}
}
};
template <typename LowLength>
struct pass_through : public base_transform<1, 1>
{
static constexpr auto type_enum = coord_transform_enum::pass_through;
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(LowLength{}));
UpLengths up_lengths_;
CK_TILE_HOST_DEVICE constexpr pass_through() = default;
CK_TILE_HOST_DEVICE constexpr pass_through(const LowLength& low_length)
: up_lengths_{make_tuple(low_length)}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::pass_through;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up)
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}];
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&)
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
idx_diff_low[I0] = idx_diff_up[I0];
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("pass_through{");
//
printf("up_lengths_:");
print(up_lengths_);
//
printf("}");
}
};
template <typename LowLength,
typename LeftPadLength,
typename RightPadLength,
bool SkipIsValidCheck = false>
struct pad : public base_transform<1, 1>
{
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{}));
UpLengths up_lengths_;
LeftPadLength left_pad_length_;
RightPadLength right_pad_length_;
CK_TILE_HOST_DEVICE constexpr pad() : up_lengths_{}, left_pad_length_{}, right_pad_length_{} {}
CK_TILE_HOST_DEVICE constexpr pad(const LowLength& low_length,
const LeftPadLength& left_pad_length,
const RightPadLength& right_pad_length)
: up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)},
left_pad_length_{left_pad_length},
right_pad_length_{right_pad_length}
{
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&)
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
idx_diff_low[I0] = idx_diff_up[I0];
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return SkipIsValidCheck;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& idx_up) const
{
return SkipIsValidCheck ||
((idx_up[number<0>{}] >= left_pad_length_) &&
(idx_up[number<0>{}] < up_lengths_[number<0>{}] - right_pad_length_));
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<LeftPadLength>::value &&
ck_tile::is_known_at_compile_time<RightPadLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
struct left_pad
{
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{}));
UpLengths up_lengths_;
LeftPadLength left_pad_length_;
CK_TILE_HOST_DEVICE constexpr left_pad() = default;
CK_TILE_HOST_DEVICE constexpr left_pad(const LowLength& low_length,
const LeftPadLength& left_pad_length)
: up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length}
{
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&)
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
idx_diff_low[I0] = idx_diff_up[I0];
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return SkipIsValidCheck;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& idx_up) const
{
return SkipIsValidCheck || (idx_up[number<0>{}] >= left_pad_length_);
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<LeftPadLength>::value;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
// TODO: we allow this vector length to pass through. If a per-pixel check is needed,
// the guaranteed vector length should be changed when creating the tensor view.
// It's up to the runtime to check that the padding length is a multiple of the vector length
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("left_pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf("}");
}
};
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
struct right_pad : public base_transform<1, 1>
{
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{}));
UpLengths up_lengths_;
LowLength low_length_;
RightPadLength right_pad_length_;
CK_TILE_HOST_DEVICE constexpr right_pad() = default;
CK_TILE_HOST_DEVICE constexpr right_pad(const LowLength& low_length,
const RightPadLength& right_pad_length)
: up_lengths_{make_tuple(low_length + right_pad_length)},
low_length_{low_length},
right_pad_length_{right_pad_length}
{
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up)
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}];
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&)
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
idx_diff_low[I0] = idx_diff_up[I0];
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return SkipIsValidCheck;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& idx_up) const
{
return SkipIsValidCheck || (idx_up[number<0>{}] < low_length_);
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<LowLength>::value &&
ck_tile::is_known_at_compile_time<RightPadLength>::value;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
// TODO: we allow this vector length to pass through. If a per-pixel check is needed,
// the guaranteed vector length should be changed when creating the tensor view.
// It's up to the runtime to check that the padding length is a multiple of the vector length
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("right_pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
// UpLengths and Coefficients can be any of the following:
// 1) Tuple of index_t, which is known at run-time, or
// 2) Tuple of number, which is known at compile-time, or
// 3) Tuple of mixture of index_t and number, which is known partially at run-time and partially
// at compile-time
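// Illustrative example (values made up): with coefficients = (12, 4, 1), the upper index
// (1, 2, 3) lowers to 1 * 12 + 2 * 4 + 3 * 1 = 23, i.e. a row-major offset for lengths (2, 3, 4).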
template <typename UpLengths,
typename Coefficients,
typename std::enable_if<UpLengths::size() == Coefficients::size(), bool>::type = false>
struct embed : public base_transform<1, UpLengths::size()>
{
static constexpr index_t NDimUp = UpLengths::size();
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<NDimUp>;
UpLengths up_lengths_;
Coefficients coefficients_;
CK_TILE_HOST_DEVICE constexpr embed() = default;
CK_TILE_HOST_DEVICE constexpr embed(const UpLengths& up_lengths,
const Coefficients& coefficients)
: up_lengths_{up_lengths}, coefficients_{coefficients}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::embed;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = 0;
static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
idx_low(number<0>{}) += idx_up[i] * this->coefficients_[i];
});
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&) const
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
LowIdx::size() == 1 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_diff_low(number<0>{}) = 0;
static_for<0, NDimUp, 1>{}(
[&](auto i) { idx_diff_low(number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<Coefficients>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("coefficients_: ");
print(coefficients_);
printf("}");
}
};
template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_divisor
{
template <index_t I>
CK_TILE_HOST_DEVICE constexpr auto operator()(number<I> i) const
{
return magic_division::calculate_magic_numbers(LowLengths{}[i]);
}
};
// Implementation of "merge" transformation primitive that uses magic-number-division to do lowering
// of both multi-index and delta of multi-index
// Caution:
// 1. The magic number division implementation being used would produce correct result if the
// dividended is uint32_t and its value is with in 31-bit value range of uint32_t.
// 2. The magic number division for int32_t dividened has not been implemented, the int32_t
// dividend would be bit-wise interpreted as uint32_t and magic number division implementation for
// uint32_t is then used.
// 3. For merge primitive, upper-index is the dividend.
// 4. When upper-index is uint32_t, its value need to be within 31-bit range.
// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be
// non-negative.
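// Illustrative example (values made up): with low_lengths = (4, 8) the upper length is 32, and
// upper index 19 lowers to (19 / 8, 19 % 8) = (2, 3); the magic-number path computes the same
// quotient and remainder without integer division instructions.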
template <typename LowLengths>
struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
{
static constexpr index_t NDimLow = LowLengths::size();
using LowerIndex = multi_index<NDimLow>;
using UpperIndex = multi_index<1>;
using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
using LowLengthsMagicDivisor = decltype(generate_tuple(
lambda_merge_generate_MagicDivision_calculate_magic_divisor<LowLengths>{},
number<NDimLow>{}));
LowLengths low_lengths_;
LowLengthsMagicDivisor low_lengths_magic_divisor_;
UpLengths up_lengths_;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
CK_TILE_HOST_DEVICE constexpr merge_v2_magic_division() = default;
CK_TILE_HOST_DEVICE constexpr merge_v2_magic_division(const LowLengths& low_lengths)
: low_lengths_{low_lengths},
low_lengths_magic_divisor_{generate_tuple(
[&](auto i) { return magic_division::calculate_magic_numbers(low_lengths[i]); },
number<NDimLow>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, I1))}
{
static_assert(LowerIndex::size() == NDimLow, "wrong!");
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::merge;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
index_t tmp = idx_up[I0];
static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
index_t tmp2 =
magic_division::do_magic_division(tmp,
this->low_lengths_magic_divisor_[i][I0],
this->low_lengths_magic_divisor_[i][I1]);
idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
tmp = tmp2;
});
idx_low(number<0>{}) = tmp;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff&,
LowIdx& idx_low,
const UpIdx& idx_up_new) const
{
static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 &&
LowIdx::size() == NDimLow && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
index_t tmp = idx_up_new[number<0>{}];
static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
index_t tmp2 =
magic_division::do_magic_division(tmp,
this->low_lengths_magic_divisor_[i][I0],
this->low_lengths_magic_divisor_[i][I1]);
index_t idx_low_old = idx_low[i];
idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
tmp = tmp2;
idx_diff_low(i) = idx_low[i] - idx_low_old;
});
idx_diff_low(number<0>{}) = tmp - idx_low(number<0>{});
idx_low(number<0>{}) = tmp;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<LowLengths>::value &&
ck_tile::is_known_at_compile_time<LowLengthsMagicDivisor>::value &&
ck_tile::is_known_at_compile_time<UpLengths>::value;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
array<index_t, 1> up_vector_lengths{-1};
array<index_t, 1> up_vector_strides{-1};
up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("merge_v2_magic_division{");
//
printf("low_lengths_ ");
print(low_lengths_);
printf(", ");
//
printf("up_lengths_ ");
print(up_lengths_);
printf("}");
}
};
// Implementation of "merge" transformation primitive that uses division and mod. It is supposed to
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
// will be very bad
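// Illustrative example (values made up): for low_lengths = (4, 8) the reverse exclusive scan is
// (8, 1), so upper index 19 lowers to (19 / 8, (19 % 8) / 1) = (2, 3) using plain division and mod.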
template <typename LowLengths>
struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
{
static constexpr index_t NDimLow = LowLengths::size();
using LowerIndex = multi_index<NDimLow>;
using UpperIndex = multi_index<1>;
using LowLengthsScan =
decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies{}, number<1>{}));
using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
LowLengths low_lengths_;
LowLengthsScan low_lengths_scan_;
UpLengths up_lengths_;
CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod() = default;
CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod(const LowLengths& low_lengths)
: low_lengths_{low_lengths},
low_lengths_scan_{
container_reverse_exclusive_scan(low_lengths, multiplies{}, number<1>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, number<1>{}))}
{
static_assert(LowerIndex::size() == NDimLow, "wrong!");
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
index_t tmp = idx_up[number<0>{}];
// division and mod
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
idx_low(i) = tmp / this->low_lengths_scan_[i];
tmp %= this->low_lengths_scan_[i];
});
idx_low(number<NDimLow - 1>{}) = tmp;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff&,
LowIdx& idx_low,
const UpIdx& idx_up_new) const
{
static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 &&
LowIdx::size() == NDimLow && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
constexpr auto INm1 = number<NDimLow - 1>{};
index_t tmp = idx_up_new[I0];
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
const index_t tmp2 = idx_low[i];
idx_low(i) = tmp / this->low_lengths_scan_[i];
idx_diff_low(i) = idx_low[i] - tmp2;
tmp %= this->low_lengths_scan_[i];
});
const index_t tmp2 = idx_low[INm1];
idx_low(INm1) = tmp;
idx_diff_low(INm1) = idx_low[INm1] - tmp2;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<LowLengths>::value &&
ck_tile::is_known_at_compile_time<LowLengthsScan>::value &&
ck_tile::is_known_at_compile_time<UpLengths>::value;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
array<index_t, 1> up_vector_lengths{-1};
array<index_t, 1> up_vector_strides{-1};
up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("Merge_v3_direct_division_mod{");
//
printf("low_lengths_ ");
print(low_lengths_);
printf(", ");
//
printf("low_lengths_scan_ ");
print(low_lengths_scan_);
printf(", ");
//
printf("up_lengths_ ");
print(up_lengths_);
printf("}");
}
};
template <typename UpLengths, bool Use24BitIntegerCalculation>
struct unmerge : public base_transform<1, UpLengths::size()>
{
static constexpr index_t NDimUp = UpLengths::size();
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<NDimUp>;
using UpLengthsScan =
decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies{}, number<1>{}));
UpLengths up_lengths_;
UpLengthsScan up_lengths_scan_;
CK_TILE_HOST_DEVICE constexpr unmerge() = default;
CK_TILE_HOST_DEVICE constexpr unmerge(const UpLengths& up_lengths)
: up_lengths_{up_lengths},
up_lengths_scan_{container_reverse_exclusive_scan(up_lengths, multiplies{}, number<1>{})}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::unmerge;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
if constexpr(!Use24BitIntegerCalculation)
{
idx_low(number<0>{}) = idx_up[number<NDimUp - 1>{}];
static_for<0, NDimUp - 1, 1>{}(
[&](auto i) { idx_low(number<0>{}) += idx_up[i] * up_lengths_scan_[i]; });
}
else
{
idx_low(number<0>{}) = idx_up[number<NDimUp - 1>{}];
static_for<0, NDimUp - 1, 1>{}([&](auto i) {
idx_low(number<0>{}) =
(0x00ffffff & idx_low[number<0>{}]) +
(0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]);
});
}
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&) const
{
calculate_lower_index(idx_diff_low, idx_diff_up);
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<UpLengthsScan>::value;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
array<index_t, NDimUp> up_vector_lengths{-1};
array<index_t, NDimUp> up_vector_strides{-1};
constexpr auto up_length_last = UpLengths{}[number<NDimUp - 1>{}];
if constexpr(ck_tile::is_known_at_compile_time<decltype(up_length_last)>::value)
{
if(low_vector_lengths[0] != -1)
{
up_vector_lengths(NDimUp - 1) = gcd(low_vector_lengths[0], up_length_last);
}
}
up_vector_strides(NDimUp - 1) = low_vector_strides[0];
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("unmerge{");
//
printf("up_lengths_");
print(up_lengths_);
printf(", ");
//
printf("up_lengths_scan_");
print(up_lengths_scan_);
printf("}");
}
};
template <typename LowerIndex>
struct freeze : public base_transform<1, 0>
{
LowerIndex low_idx_;
CK_TILE_HOST_DEVICE constexpr freeze() = default;
CK_TILE_HOST_DEVICE constexpr freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}
CK_TILE_HOST_DEVICE static constexpr auto get_upper_lengths() { return tuple<>{}; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& /* idx_up */) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 0,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = low_idx_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& /* idx_diff_up */,
LowIdx& /* idx_low */,
const UpIdx& /* idx_up_new */)
{
idx_diff_low(number<0>{}) = 0;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<LowerIndex>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("freeze{");
//
printf("low_idx_: ");
print(low_idx_);
printf("}");
}
};
// insert a dangling upper dimension without lower dimension
template <typename UpperLength>
struct insert : public base_transform<0, 1>
{
using UpLengths = decltype(make_tuple(UpperLength{}));
UpLengths up_lengths_;
CK_TILE_HOST_DEVICE constexpr insert() = default;
CK_TILE_HOST_DEVICE constexpr insert(const UpperLength& up_length)
: up_lengths_{make_tuple(up_length)}
{
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_lower_dimension() { return 0; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_upper_dimension() { return 1; }
CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const
{
static_assert(LowIdx::size() == 0 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void
update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&)
{
static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == 1 && LowIdx::size() == 0 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
}
CK_TILE_HOST_DEVICE static constexpr bool IsLinearTransform() { return true; }
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpperLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("insert{");
//
print(up_lengths_);
printf("}");
}
};
// replicate the original tensor and create a higher dimensional tensor
template <typename UpLengths>
struct replicate : public base_transform<0, UpLengths::size()>
{
static constexpr index_t NDimUp = UpLengths::size();
CK_TILE_HOST_DEVICE constexpr replicate() = default;
CK_TILE_HOST_DEVICE constexpr replicate(const UpLengths& up_lengths) : up_lengths_{up_lengths}
{
}
CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const
{
static_assert(LowIdx::size() == 0 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void
update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&)
{
static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == NDimUp &&
LowIdx::size() == 0 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("replicate{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
//
UpLengths up_lengths_;
};
template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct slice : public base_transform<1, 1>
{
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{}));
UpLengths up_lengths_;
SliceBegin slice_begin_;
SliceEnd slice_end_;
CK_TILE_HOST_DEVICE constexpr slice() = default;
CK_TILE_HOST_DEVICE constexpr slice(const LowLength&,
const SliceBegin& slice_begin,
const SliceEnd& slice_end)
: up_lengths_{make_tuple(slice_end - slice_begin)},
slice_begin_{slice_begin},
slice_end_{slice_end}
{
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}] + slice_begin_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&)
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
idx_diff_low[I0] = idx_diff_up[I0];
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx&) const
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<SliceBegin>::value &&
ck_tile::is_known_at_compile_time<SliceEnd>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("slice{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("slice_begin_: ");
print(slice_begin_);
printf(", ");
//
printf("slice_end_: ");
print(slice_end_);
printf("}");
}
};
/*
* \brief lower_idx = upper_idx % modulus.
* TODO: Need an improved implementation since the modulo operation is expensive.
*/
template <typename Modulus, typename UpLength>
struct modulo : public base_transform<1, 1>
{
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(UpLength{}));
Modulus modulus_;
UpLengths up_lengths_;
CK_TILE_HOST_DEVICE constexpr modulo() = default;
CK_TILE_HOST_DEVICE constexpr modulo(const Modulus& modulus, const UpLength& up_length)
: modulus_{modulus}, up_lengths_{make_tuple(up_length)}
{
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}] % modulus_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& up_idx) const
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
const auto idx_low_old = idx_low;
idx_low[I0] = (up_idx[I0] + idx_diff_up[I0]) % modulus_;
idx_diff_low[I0] = idx_low - idx_low_old;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("Modulus{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
};
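// Worked example (sketch, not compiled): with modulus = 4 the transform maps an upper
// index of 9 to a lower index of 9 % 4 == 1; `make_modulo_transform` declared below builds it.
#if 0
constexpr auto mod = make_modulo_transform(number<4>{}, number<16>{});
#endif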
// 2D XOR, NOTE: "xor" is a keyword
template <typename LowLengths, typename RightShift>
struct xor_t : public base_transform<2, 2>
{
static constexpr auto type_enum = coord_transform_enum::xor_t;
using LowerIndex = multi_index<2>;
using UpperIndex = multi_index<2>;
using UpLengths = LowLengths;
UpLengths up_lengths_;
RightShift right_shift_;
CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {}
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths,
const RightShift& right_shift)
: up_lengths_{low_lengths}, right_shift_{right_shift}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::xor_t;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 2 && UpIdx::size() == 2,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}];
const auto idx_low_1_tmp =
(idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}];
const auto idx_low_1 =
(idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;
idx_low(number<1>{}) = idx_low_1;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff&,
LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdxDiff::size() == 2 && UpIdxDiff::size() == 2 && LowIdx::size() == 2 &&
UpIdx::size() == 2,
"wrong! inconsistent # of dimension");
const auto idx_low_old = idx_low;
calculate_lower_index(idx_low, idx_up);
idx_diff_low = idx_low - idx_low_old;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<RightShift>::value;
}
// NOTE: cannot be a static function here, since it reads the member right_shift_
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE constexpr auto calculate_upper_dimension_safe_vector_length_strides(
const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides) const
{
array<index_t, 2> up_vector_lengths = low_vector_lengths;
array<index_t, 2> up_vector_strides = low_vector_strides;
if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
{
if(low_vector_lengths[1] != -1)
{
up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
}
}
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("xor_t{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("right_shift_: ");
print(right_shift_);
printf("}");
}
};
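// Worked example (sketch, not compiled): an xor-style swizzle over a 4 x 8 upper view with
// right_shift = 2. Upper index (1, 3) maps to lower index (1, (3 - 1 * 2) mod 8) == (1, 1);
// a negative intermediate value wraps back into [0, 8). `make_xor_transform` is declared below.
#if 0
constexpr auto swizzle = make_xor_transform(make_tuple(number<4>{}, number<8>{}), number<2>{});
#endif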
template <typename LowLength, typename OffsetLength>
struct offset : public base_transform<1, 1>
{
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(LowLength{}));
UpLengths up_lengths_;
OffsetLength offset_length_;
CK_TILE_HOST_DEVICE constexpr offset() = default;
CK_TILE_HOST_DEVICE constexpr offset(const LowLength& low_length,
const OffsetLength& offset_length)
: up_lengths_{make_tuple(low_length)}, offset_length_{offset_length}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::offset;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}] + offset_length_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&)
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
idx_diff_low[I0] = idx_diff_up[I0];
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx&) const
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<OffsetLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("offset{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("offset_length_: ");
print(offset_length_);
printf("}");
}
};
//*******************************************************************************************************
template <typename LowLength>
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength& low_length)
{
return pass_through<LowLength>{low_length};
}
template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
CK_TILE_HOST_DEVICE constexpr auto
make_pad_transform(const LowLength& low_length,
const LeftPad& left_pad,
const RightPad& right_pad,
bool_constant<SkipIsValidCheck> = bool_constant<false>{})
{
return pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
}
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
CK_TILE_HOST_DEVICE constexpr auto
make_left_pad_transform(const LowLength& low_length,
const LeftPadLength& left_pad_,
bool_constant<SkipIsValidCheck> = bool_constant<false>{})
{
return left_pad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad_};
}
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
CK_TILE_HOST_DEVICE constexpr auto
make_right_pad_transform(const LowLength& low_length,
const RightPadLength& right_pad_,
bool_constant<SkipIsValidCheck> = bool_constant<false>{})
{
return right_pad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad_};
}
template <typename UpLengths,
typename Coefficients,
typename std::enable_if<UpLengths::size() == Coefficients::size(), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto make_embed_transform(const UpLengths& up_lengths,
const Coefficients& coefficients)
{
return embed<UpLengths, Coefficients>{up_lengths, coefficients};
}
template <typename LowLengths>
CK_TILE_HOST_DEVICE constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
{
return merge_v2_magic_division<LowLengths>{low_lengths};
}
template <typename LowLengths>
CK_TILE_HOST_DEVICE constexpr auto
make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
{
return merge_v3_division_mod<LowLengths>{low_lengths};
}
template <typename LowLengths>
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
return make_merge_transform_v2_magic_division(low_lengths);
}
template <typename UpLengths, bool Use24BitIntegerCalculation = false>
CK_TILE_HOST_DEVICE constexpr auto
make_unmerge_transform(const UpLengths& up_lengths,
bool_constant<Use24BitIntegerCalculation> = bool_constant<false>{})
{
return unmerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
}
template <typename LowerIndex>
CK_TILE_HOST_DEVICE constexpr auto make_freeze_transform(const LowerIndex& low_idx)
{
return freeze<LowerIndex>{low_idx};
}
template <typename UpperIndex>
CK_TILE_HOST_DEVICE constexpr auto make_insert_transform(const UpperIndex& up_idx)
{
return insert<UpperIndex>{up_idx};
}
template <typename UpLengths>
CK_TILE_HOST_DEVICE constexpr auto make_replicate_transform(const UpLengths& up_lengths)
{
return replicate<UpLengths>{up_lengths};
}
template <typename LowLength, typename SliceBegin, typename SliceEnd>
CK_TILE_HOST_DEVICE constexpr auto make_slice_transform(const LowLength& low_length,
const SliceBegin& slice_begin,
const SliceEnd& slice_end)
{
return slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
}
template <typename Modulus, typename UpLength>
CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
const UpLength& up_length)
{
return modulo<Modulus, UpLength>{modulus, up_length};
}
template <typename LowLengths, typename RightShift>
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths,
const RightShift& right_shift)
{
return xor_t<LowLengths, RightShift>{low_lengths, right_shift};
}
template <typename LowLength, typename OffsetLength>
CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_length,
const OffsetLength& offset_length)
{
return offset<LowLength, OffsetLength>{low_length, offset_length};
}
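// Usage sketch (illustration only, not compiled): the factory functions above are typically
// combined when building tensor descriptors/adaptors. The lengths below are hypothetical.
#if 0
constexpr auto pad_m    = make_right_pad_transform(number<60>{}, number<4>{}); // pads 60 -> 64
constexpr auto pass_k   = make_pass_through_transform(number<128>{});
constexpr auto merge_mk = make_merge_transform(make_tuple(number<64>{}, number<128>{}));
#endif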
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename TensorLengths,
typename DimAccessOrder,
typename ScalarsPerAccess, // # of scalars per access in each dimension
bool SnakeCurved = true>
struct space_filling_curve
{
static constexpr index_t TensorSize =
reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{});
static_assert(0 < TensorSize,
"space_filling_curve should be used to access a non-empty tensor");
static constexpr index_t nDim = TensorLengths::size();
using Index = multi_index<nDim>;
static constexpr index_t ScalarPerVector =
reduce_on_sequence(ScalarsPerAccess{}, multiplies{}, number<1>{});
static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{};
static constexpr auto dim_access_order = DimAccessOrder{};
static constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(ordered_access_lengths)),
make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}),
make_tuple(sequence<0>{}));
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
{
static_assert(TensorLengths::size() == ScalarsPerAccess::size());
static_assert(TensorLengths{} % ScalarsPerAccess{} ==
typename uniform_sequence_gen<TensorLengths::size(), 0>::type{});
return reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{}) / ScalarPerVector;
}
template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
static CK_TILE_HOST_DEVICE constexpr auto get_step_between(number<AccessIdx1dHead>,
number<AccessIdx1dTail>)
{
static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(),
"1D index out of range");
static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(),
"1D index out of range");
constexpr auto idx_head = get_index(number<AccessIdx1dHead>{});
constexpr auto idx_tail = get_index(number<AccessIdx1dTail>{});
return idx_tail - idx_head;
}
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number<AccessIdx1d>)
{
static_assert(AccessIdx1d < get_num_of_access(), "1D index out of range");
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
}
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr auto get_backward_step(number<AccessIdx1d>)
{
static_assert(AccessIdx1d > 0, "1D index should be larger than 0");
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
}
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr Index get_index(number<AccessIdx1d>)
{
#if 0
/*
* \todo: tensor_adaptor::calculate_bottom_index does NOT return constexpr as expected.
*/
constexpr auto ordered_access_idx = to_index_adaptor.calculate_bottom_index(make_multi_index(number<AccessIdx1d>{}));
#else
constexpr auto access_strides =
container_reverse_exclusive_scan(ordered_access_lengths, multiplies{}, number<1>{});
constexpr auto idx_1d = number<AccessIdx1d>{};
// Given the per-dimension access strides \p access_strides and the 1D index along the
// space-filling curve, compute the idim-th element of the multidimensional index.
// All constexpr variables have to be captured by VALUE.
constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr
{
constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr
{
auto res = idx_1d.value;
auto id = 0;
static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
id = res / access_strides[kdim].value;
res -= id * access_strides[kdim].value;
});
return id;
};
constexpr auto id = compute_index_impl(idim);
return number<id>{};
};
constexpr auto ordered_access_idx = generate_tuple(compute_index, number<nDim>{});
#endif
constexpr auto forward_sweep = [&]() {
statically_indexed_array<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto idim) {
index_t tmp = ordered_access_idx[I0];
static_for<1, idim, 1>{}(
[&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });
forward_sweep_(idim) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate multi-dim tensor index
auto idx_md = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto idim) {
ordered_idx(idim) =
!SnakeCurved || forward_sweep[idim]
? ordered_access_idx[idim]
: ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
ScalarsPerAccess{};
}();
return idx_md;
}
// FIXME: rename this function
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
{
constexpr auto idx = get_index(number<AccessIdx1d>{});
return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
}
};
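// Usage sketch (illustration only, not compiled): traverse a hypothetical 4 x 8 tile with a
// 1 x 4 vector access per step; the curve yields 8 accesses and per-step index deltas.
#if 0
using sfc = space_filling_curve<sequence<4, 8>, sequence<0, 1>, sequence<1, 4>>;
static_assert(sfc::get_num_of_access() == 8);
constexpr auto step01 = sfc::get_forward_step(number<0>{}); // multi-index step from access 0 to 1
#endif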
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
// 128-bit SGPR descriptor that supplies the buffer resource for buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
struct __attribute__((packed)) buffer_resource
{
const void* ptr;
uint32_t range;
uint32_t config;
};
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff)
{
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
return __builtin_bit_cast(int32x4_t, res);
}
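// Usage sketch (illustration only, not compiled): wrap a wave-uniform global pointer into a
// 128-bit buffer resource and feed it to the raw buffer intrinsics declared below.
// `p_global` is a hypothetical pointer to global memory.
#if 0
const int32x4_t res = make_wave_buffer_resource(p_global, 1024 * sizeof(float));
const float x = llvm_amdgcn_raw_buffer_load_fp32(res, /*voffset*/ 0, /*soffset*/ 0, /*glc_slc*/ 0);
#endif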
// TODO: glc/slc/...
template <index_t bytes>
struct buffer_load;
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
// TODO: the strict aliasing rule seems to fail when reinterpret_cast-ing between vector
// types (ext_vector_type(xxx))
template <>
struct buffer_load<16>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 16);
using mbuf_t = fp32x4_t;
asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<8>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 8);
using mbuf_t = fp32x2_t;
asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<4>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<2>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
using mbuf_t = float;
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<1>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <index_t bytes>
struct buffer_load_if;
template <>
struct buffer_load_if<16>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
{
static_assert(sizeof(T) == 16);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x4_t;
static_assert(sizeof(mbuf_t) == sizeof(T));
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<8>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
{
static_assert(sizeof(T) == 8);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x2_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<4>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dword %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<2>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<1>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast"
template <index_t bytes>
struct buffer_store;
template <>
struct buffer_store<16>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 16);
using mbuf_t = fp32x4_t;
asm volatile(
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_store<8>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 8);
using mbuf_t = fp32x2_t;
asm volatile(
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_store<4>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile(
"buffer_store_dword %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_store<2>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 2);
using mbuf_t = short;
asm volatile(
"buffer_store_short %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_store<1>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile(
"buffer_store_byte %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <index_t bytes>
struct buffer_store_if;
template <>
struct buffer_store_if<16>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 16);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x4_t;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <>
struct buffer_store_if<8>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 8);
auto save_exec = __builtin_amdgcn_read_exec();
// TODO: ugly. rocm-6.0/6.1 seems to need bit_cast to the same base type to avoid scratch
using mbuf_t = ext_vector_t<typename T::value_type, T::size()>;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <>
struct buffer_store_if<4>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_dword %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <>
struct buffer_store_if<2>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 2);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = short;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_short %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <>
struct buffer_store_if<1>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_byte %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// clang-format off
namespace impl{
// can't use "+v" since there could be potential extra move(read/write)
// use "v" can help remove such duplicated moves
// besides, fake this as "memory" operation to force later valu after this fence
// TODO: may have scratch (because this is memory?)
// need to reduce extra move inside compiler
template<index_t N>
CK_TILE_DEVICE void insert_dummy_dep_per_dword(array<float, N>& b)
{
static_for<0, b.size(), 1>{}([&](auto i){
asm volatile(" " : : "v"(b.get(i)) : "memory");
});
}
#if 1
// the specializations below merge all dwords into a single asm statement
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<2>(array<float, 2>& b)
{
asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})) : "memory");
}
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<3>(array<float, 3>& b)
{
asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})) : "memory");
}
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<4>(array<float, 4>& b)
{
asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})) : "memory");
}
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<8>(array<float, 8>& b)
{
asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})),
"v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})) : "memory");
}
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<16>(array<float, 16>& b)
{
asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})),
"v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})),
"v"(b.get(number<8>{})), "v"(b.get(number<9>{})), "v"(b.get(number<10>{})), "v"(b.get(number<11>{})),
"v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})) : "memory");
}
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<32>(array<float, 32>& b)
{
asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})),
"v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})),
"v"(b.get(number<8>{})), "v"(b.get(number<9>{})), "v"(b.get(number<10>{})), "v"(b.get(number<11>{})),
"v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})),
"v"(b.get(number<16>{})), "v"(b.get(number<17>{})), "v"(b.get(number<18>{})), "v"(b.get(number<19>{})),
"v"(b.get(number<20>{})), "v"(b.get(number<21>{})), "v"(b.get(number<22>{})), "v"(b.get(number<23>{})),
"v"(b.get(number<24>{})), "v"(b.get(number<25>{})), "v"(b.get(number<26>{})), "v"(b.get(number<27>{})),
"v"(b.get(number<28>{})), "v"(b.get(number<29>{})), "v"(b.get(number<30>{})), "v"(b.get(number<31>{})) : "memory");
}
#endif
CK_TILE_DEVICE void insert_dummy_dep() {}
template<typename T>
CK_TILE_DEVICE void insert_dummy_dep(T & buffer)
{
// TODO: T is expected to be a multiple of a dword; sub-dword accesses are always buggy
using da_type = array<float, (sizeof(T) + 3) / 4>;
auto & dummy = reinterpret_cast<da_type&>(buffer);
insert_dummy_dep_per_dword(dummy);
}
template<typename Tx, typename... Ty>
CK_TILE_DEVICE void insert_dummy_dep(Tx& bx, Ty&... by)
{
insert_dummy_dep(bx);
insert_dummy_dep(by...);
}
}
// clang-format on
template <typename... T>
CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0, T&... o)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
impl::insert_dummy_dep(o...);
}
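// Usage sketch (illustration only, not compiled): wait until all outstanding buffer loads have
// completed and pin the destination buffers so later VALU instructions are not hoisted above
// the fence. `buf_a` and `buf_b` are hypothetical thread buffers filled by earlier loads.
#if 0
buffer_load_fence(0, buf_a, buf_b);
#endif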
CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// buffer load i8
CK_TILE_DEVICE_EXTERN int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
CK_TILE_DEVICE_EXTERN int8x2_t
llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
CK_TILE_DEVICE_EXTERN int8x4_t
llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
// buffer load i16
CK_TILE_DEVICE_EXTERN int16_t
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
CK_TILE_DEVICE_EXTERN int16x2_t
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
CK_TILE_DEVICE_EXTERN int16x4_t
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");
// buffer load i32
CK_TILE_DEVICE_EXTERN int32_t
llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
CK_TILE_DEVICE_EXTERN int32x2_t
llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
CK_TILE_DEVICE_EXTERN int32x4_t
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
// buffer load fp16
CK_TILE_DEVICE_EXTERN _Float16
llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
CK_TILE_DEVICE_EXTERN fp16x2_t
llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
CK_TILE_DEVICE_EXTERN fp16x4_t
llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
// buffer load fp32
CK_TILE_DEVICE_EXTERN float
llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");
CK_TILE_DEVICE_EXTERN fp32x2_t
llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");
CK_TILE_DEVICE_EXTERN fp32x4_t
llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
// buffer store i8
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
// buffer store i16
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16x2(int16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16x4(int16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
// buffer store i32
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
// buffer store fp16
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16(_Float16 vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16x2(fp16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
// buffer store fp32
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32x2(fp32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
// buffer atomic-add fp16
CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
fp16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
// buffer atomic-add i32
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
int32_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
// buffer atomic-add fp32
CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32(
float vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
// buffer atomic-max fp64
CK_TILE_DEVICE_EXTERN double
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int32x4_t rsrc, // dst_wave_buffer_resource
int voffset, // dst_thread_addr_offset
int soffset, // dst_wave_addr_offset
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
CK_TILE_DEVICE void async_buffer_load_dword(void* smem,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t ioffset /*max 0xFFF*/,
index_t /*flag*/ = 0)
{
asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset)
: "memory");
}
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// Memory coherency bits for buffer load/store instructions.
// Check the ISA manual for each GFX target, e.g.
// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf,
// pages 67-68
enum struct amd_buffer_coherence_enum
{
coherence_default = 0, // default value
glc = 1,
slc = 2,
glc_slc = 3,
};
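// Usage sketch (illustration only, not compiled): the enum value is passed through as the
// glc/slc argument of the raw buffer intrinsics, e.g. a glc (coherent) 4-dword load via the
// helper defined further below.
#if 0
auto v = amd_buffer_load_impl<float, 4, amd_buffer_coherence_enum::glc>(res, voffset, soffset);
#endif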
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE thread_buffer<int8_t, N>
amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");
using rtn_type = thread_buffer<int8_t, N>;
if constexpr(N == 1)
{
return bit_cast<rtn_type>(llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 2)
{
int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 4)
{
int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 8)
{
int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 16)
{
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 32)
{
int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
int32x4_t tmp1 =
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t),
static_cast<index_t>(coherence));
thread_buffer<int32_t, 8> tmp;
tmp.template get_as<int32x4_t>()(number<0>{}) = tmp0;
tmp.template get_as<int32x4_t>()(number<1>{}) = tmp1;
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 64)
{
int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
int32x4_t tmp1 =
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t),
static_cast<index_t>(coherence));
int32x4_t tmp2 =
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 8 * sizeof(int32_t),
static_cast<index_t>(coherence));
int32x4_t tmp3 =
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 12 * sizeof(int32_t),
static_cast<index_t>(coherence));
thread_buffer<int32_t, 16> tmp;
tmp.template get_as<int32x4_t>()(number<0>{}) = tmp0;
tmp.template get_as<int32x4_t>()(number<1>{}) = tmp1;
tmp.template get_as<int32x4_t>()(number<2>{}) = tmp2;
tmp.template get_as<int32x4_t>()(number<3>{}) = tmp3;
return bit_cast<rtn_type>(tmp);
}
}
#ifndef BUFFER_LOAD_USE_INLINEASM
#define BUFFER_LOAD_USE_INLINEASM 0
#endif
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
static_assert(
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int32_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
using rtn_type = thread_buffer<T, N>;
if constexpr(std::is_same<T, float>::value) // fp32
{
if constexpr(N == 1)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 2)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 4)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 8)
{
thread_buffer<float, 8> tmp;
tmp.template get_as<fp32x4_t>()(number<0>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
return tmp;
}
else if constexpr(N == 16)
{
thread_buffer<float, 16> tmp;
tmp.template get_as<fp32x4_t>()(number<0>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<2>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 8 * sizeof(float),
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<3>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 12 * sizeof(float),
static_cast<index_t>(coherence));
return tmp;
}
}
else if constexpr(std::is_same<T, fp16_t>::value) // fp16
{
if constexpr(N == 1)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 2)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 4)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 8)
{
// use fp32 load to mimic fp16 load
fp32x4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
}
else if constexpr(std::is_same<T, bf16_t>::value) // bf16
{
if constexpr(N == 1)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 2)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 4)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 8)
{
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
}
else // other datatype
{
auto raw_data = amd_buffer_load_impl_with_bytes<sizeof(T) * N, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset);
return bit_cast<rtn_type>(raw_data);
}
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t flag = 0)
{
constexpr index_t bytes = sizeof(T) * N;
static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
"wrong! not supported by buffer_load instruction");
using type = thread_buffer<T, N>;
if constexpr(oob_conditional_check)
{
buffer_load_if<sizeof(type)>{}(
dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag);
}
else
{
buffer_load<sizeof(type)>{}(
dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag);
}
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0)
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
async_buffer_load_dword(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset);
}
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_i8(bit_cast<int8_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
llvm_amdgcn_raw_buffer_store_i32x2(bit_cast<int32x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 16)
{
llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 32)
{
llvm_amdgcn_raw_buffer_store_i32x4(
src_thread_data.template get_as<int32x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i32x4(
src_thread_data.template get_as<int32x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 4,
static_cast<index_t>(coherence));
}
else if constexpr(N == 64)
{
llvm_amdgcn_raw_buffer_store_i32x4(
src_thread_data.template get_as<int32x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i32x4(
src_thread_data.template get_as<int32x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 4,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i32x4(
src_thread_data.template get_as<int32x4_t>()[number<2>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 8,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i32x4(
src_thread_data.template get_as<int32x4_t>()[number<3>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 12,
static_cast<index_t>(coherence));
}
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert(
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int32_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
if constexpr(std::is_same<T, float>::value) // fp32
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp32(bit_cast<float>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast<fp32x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<fp32x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
llvm_amdgcn_raw_buffer_store_fp32x4(
src_thread_data.template get_as<fp32x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp32x4(
src_thread_data.template get_as<fp32x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
}
}
else if constexpr(std::is_same<T, fp16_t>::value) // fp16
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp16(bit_cast<_Float16>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp16x2(bit_cast<fp16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_fp16x4(bit_cast<fp16x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
#if 0
thread_buffer<fp16_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(fp16_t),
static_cast<index_t>(coherence));
#else
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<fp32x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
#endif
}
}
else if constexpr(std::is_same<T, bf16_t>::value) // bf16
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_i16x2(bit_cast<int16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_i16x4(bit_cast<int16x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
llvm_amdgcn_raw_buffer_store_i16x4(
src_thread_data.template get_as<int16x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i16x4(
src_thread_data.template get_as<int16x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(bf16_t),
static_cast<index_t>(coherence));
}
}
else
{
using r_t = thread_buffer<int8_t, sizeof(T) * N>;
amd_buffer_store_impl_with_bytes<sizeof(T) * N, coherence>(bit_cast<r_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset);
}
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset,
index_t is_valid_element = 1)
{
constexpr index_t bytes = sizeof(T) * N;
static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
"wrong! not supported by buffer_store instruction");
using type = thread_buffer<T, N>;
if constexpr(oob_conditional_check)
{
buffer_store_if<sizeof(type)>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0,
is_valid_element);
}
else
{
buffer_store<sizeof(type)>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
"wrong! not implemented");
if constexpr(std::is_same<T, float>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_atomic_add_fp32(bit_cast<float>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_atomic_add_fp32(
src_thread_data.template get_as<float>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(
src_thread_data.template get_as<float>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(float),
0);
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_atomic_add_fp32(
src_thread_data.template get_as<float>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(
src_thread_data.template get_as<float>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(float),
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(
src_thread_data.template get_as<float>()[number<2>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 2 * sizeof(float),
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(
src_thread_data.template get_as<float>()[number<3>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 3 * sizeof(float),
0);
}
}
else if constexpr(std::is_same<T, fp16_t>::value)
{
if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast<fp16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
static_for<0, 2, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
src_thread_data.template get_as<fp16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(fp16x2_t),
0);
});
}
else if constexpr(N == 8)
{
static_for<0, 4, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
src_thread_data.template get_as<fp16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(fp16x2_t),
0);
});
}
}
else if constexpr(std::is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_atomic_add_i32(bit_cast<int32_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_atomic_add_i32(
src_thread_data.template get_as<int32_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_add_i32(
src_thread_data.template get_as<int32_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t),
0);
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_atomic_add_i32(
src_thread_data.template get_as<int32_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_add_i32(
src_thread_data.template get_as<int32_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t),
0);
llvm_amdgcn_raw_buffer_atomic_add_i32(
src_thread_data.template get_as<int32_t>()[number<2>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 2 * sizeof(int32_t),
0);
llvm_amdgcn_raw_buffer_atomic_add_i32(
src_thread_data.template get_as<int32_t>()[number<3>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 3 * sizeof(int32_t),
0);
}
}
}
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const thread_buffer<T, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert((std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4)),
"wrong! not implemented");
if constexpr(std::is_same<T, double>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_atomic_max_fp64(bit_cast<double>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_atomic_max_fp64(
src_thread_data.template get_as<double>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_max_fp64(
src_thread_data.template get_as<double>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(double),
0);
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_atomic_max_fp64(
src_thread_data.template get_as<double>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_max_fp64(
src_thread_data.template get_as<double>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(double),
0);
llvm_amdgcn_raw_buffer_atomic_max_fp64(
src_thread_data.template get_as<double>()[number<2>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 2 * sizeof(double),
0);
llvm_amdgcn_raw_buffer_atomic_max_fp64(
src_thread_data.template get_as<double>()[number<3>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 3 * sizeof(double),
0);
}
}
}
// buffer_load requires:
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
// oob_conditional_check: dynamically check whether the access is out-of-bound
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE thread_buffer<T, N>
amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
index_t src_thread_element_offset,
bool src_thread_element_valid,
index_t src_element_space_size)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = [&]() {
if constexpr(oob_conditional_check)
return src_thread_element_valid ? 0 : 0x80000000;
else
return 0;
}();
return amd_buffer_load_impl<T, N, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
thread_buffer<T, N> tmp =
amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
if constexpr(oob_conditional_check)
return src_thread_element_valid ? tmp : thread_buffer<T, N>{numeric<T>::zero()};
else
return tmp;
#endif
}
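// Example (illustrative sketch only, not part of the library): how a caller might use
// amd_buffer_load_invalid_element_return_zero to gather 4 floats per thread, with any
// out-of-bound access returning zero. The names example_gather_fp32x4 / p_global / n are
// assumptions for illustration.
#if 0
CK_TILE_DEVICE thread_buffer<float, 4> example_gather_fp32x4(const float* p_global, index_t n)
{
    // each thread owns a contiguous 4-element chunk
    const index_t offset = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
    const bool valid     = (offset + 4) <= n; // the whole 4-wide access must be in range
    return amd_buffer_load_invalid_element_return_zero<float, 4>(p_global, offset, valid, n);
}
#endif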
// buffer_load requires:
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE thread_buffer<T, N>
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
index_t src_thread_element_offset,
bool src_thread_element_valid,
index_t src_element_space_size,
T customized_value)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
thread_buffer<T, N> tmp =
amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
if constexpr(oob_conditional_check)
return src_thread_element_valid ? tmp : thread_buffer<T, N>{customized_value};
else
return tmp;
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_element_space_size,
index_t is_valid_element = 0)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check>(
dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element);
}
// Unfortunately async copy cannot guarantee that invalid data is zero inside LDS,
// unless the user manually writes zero to LDS at the proper address,
// so the invalid_element check is not supported for now.
// buffer_load OOB checking still works.
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_element_space_size)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>(
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0);
}
// buffer_store requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_buffer_store(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = [&]() {
if constexpr(oob_conditional_check)
return dst_thread_element_valid ? 0 : 0x80000000;
else
return 0;
}();
amd_buffer_store_impl<T, N, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if constexpr(oob_conditional_check)
{
if(dst_thread_element_valid)
{
amd_buffer_store_impl<T, N, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
}
else
{
amd_buffer_store_impl<T, N, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
}
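// Example (illustrative sketch only): the store-side counterpart of the load above. The
// helper name and parameters are assumptions for illustration.
#if 0
CK_TILE_DEVICE void
example_store_fp32x4(const thread_buffer<float, 4>& data, float* p_dst, index_t n)
{
    const index_t offset = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
    const bool valid     = (offset + 4) <= n;
    // out-of-range threads either shift the address outside the buffer window or skip the
    // store, depending on CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    amd_buffer_store<float, 4>(data, p_dst, offset, valid, n);
}
#endif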
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_thread_element_valid);
}
// buffer_atomic_add requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_add_impl<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_element_valid)
{
amd_buffer_atomic_add_impl<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
}
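// Example (illustrative sketch only): accumulating a per-thread fp32 partial result into a
// global buffer, e.g. for a split-K style reduction. Names are assumptions for illustration.
#if 0
CK_TILE_DEVICE void example_accumulate_fp32x2(const thread_buffer<float, 2>& partial,
                                              float* p_acc,
                                              index_t element_offset,
                                              index_t acc_space_size)
{
    const bool valid = (element_offset + 2) <= acc_space_size;
    amd_buffer_atomic_add<float, 2>(partial, p_acc, element_offset, valid, acc_space_size);
}
#endif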
// buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_max_impl<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_element_valid)
{
amd_buffer_atomic_max_impl<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
}
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <typename T, index_t NumElemsPerThread>
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t global_offset,
T* lds_base_ptr,
const index_t lds_offset,
const bool is_valid,
const index_t src_element_space_size)
{
// Direct loads require that each thread reads and writes exactly a single DWORD.
constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
static_assert(bytes_per_thread == dword_bytes);
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
const int32x4_t src_resource =
make_wave_buffer_resource(global_ptr, src_element_space_size * sizeof(T));
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
#if CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
T* lds_ptr = lds_base_ptr + lds_offset;
auto const lds_ptr_sgpr =
__builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
asm volatile("s_mov_b32 m0, %0; \n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(global_offset_bytes),
"s"(src_resource));
#else
// LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
#endif
}
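// Example (illustrative sketch only): because each direct load moves exactly one DWORD per
// thread, an fp16 tile would use NumElemsPerThread = 2 (2 bytes * 2 elements = 4 bytes).
// Names are assumptions for illustration.
#if 0
CK_TILE_DEVICE void example_direct_load_fp16(const fp16_t* p_global,
                                             fp16_t* p_lds,
                                             index_t global_element_offset,
                                             index_t lds_element_offset,
                                             bool is_valid,
                                             index_t global_space_size)
{
    amd_direct_load_global_to_lds<fp16_t, 2>(
        p_global, global_element_offset, p_lds, lds_element_offset, is_valid, global_space_size);
}
#endif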
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
namespace ck_tile {
enum struct address_space_enum
{
generic,
global,
lds,
sgpr,
vgpr,
};
enum struct memory_operation_enum
{
set,
atomic_add,
atomic_max,
add
};
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
{
// warpSize is defined by HIP
return warpSize;
}
CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }
CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; }
// TODO: deprecate these
CK_TILE_DEVICE index_t get_thread_local_1d_id() { return threadIdx.x; }
CK_TILE_DEVICE index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }
CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }
// Use these instead
CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
CK_TILE_DEVICE index_t get_warp_id()
{
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
}
CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
CK_TILE_DEVICE void block_sync_lds()
{
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
asm volatile("\
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
#else
__syncthreads();
#endif
}
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
asm volatile("\
s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
}
CK_TILE_DEVICE void s_nop()
{
#if 1
asm volatile("\
s_nop 0 \n \
" ::);
#else
__builtin_amdgcn_sched_barrier(0);
#endif
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <stdint.h>
namespace ck_tile {
// TODO: we have "memory" clobber here because this inline asm is used for async copy
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
{
asm volatile("s_mov_b32 m0, %0" : : "s"(v) : "memory");
}
// NOTE: this is an immediate value
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
{
asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory");
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
{
#if 0
return __shfl_up(v_local, lane_delta);
#elif 1
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
const uint32_t wrap_around_lane_delta = warpSize - lane_delta;
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
return bit_cast<T>(v_remote_tmp);
#endif
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
{
#if 0
return __shfl_down(v_local, lane_delta);
#elif 1
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
(__lane_id() << 2) + (lane_delta << 2), bit_cast<int32_t>(v_local));
return bit_cast<T>(v_remote_tmp);
#endif
}
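// Example (illustrative sketch only): a tree reduction over a wavefront built on
// warp_shuffle_down; the ds_bpermute path above restricts T to 4-byte payloads.
#if 0
CK_TILE_DEVICE float example_warp_reduce_sum(float v)
{
    for(uint32_t delta = warpSize / 2; delta > 0; delta /= 2)
    {
        v += warp_shuffle_down(v, delta);
    }
    return v; // lane 0 holds the sum over the whole wavefront
}
#endif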
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#endif
#ifdef __HIPCC__
#define CK_TILE_HOST inline __host__
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__
#else
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_DEVICE_EXTERN
#endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#endif
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
#endif
#define CK_TILE_FLOAT_TO_FP8_STANDARD 0
#define CK_TILE_FLOAT_TO_FP8_STOCHASTIC 1
#ifndef CK_TILE_FLOAT_TO_FP8_DEFAULT
#define CK_TILE_FLOAT_TO_FP8_DEFAULT CK_TILE_FLOAT_TO_FP8_STANDARD
#endif
// With older ROCm releases we had to fall back to a tuple-array implementation for this,
// so turn on _USE_TUPLE if you hit a compiler error; otherwise _USE_ARRAY is the default.
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY 0
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE 1
#ifndef CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
#endif
#define CK_TILE_THREAD_BUFFER_USE_ARRAY 0
#define CK_TILE_THREAD_BUFFER_USE_TUPLE 1
#ifndef CK_TILE_THREAD_BUFFER_DEFAULT
#define CK_TILE_THREAD_BUFFER_DEFAULT CK_TILE_THREAD_BUFFER_USE_ARRAY
#endif
#ifndef CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
// if the tuple-array is used as the thread_buffer implementation, it needs to support {} brace init
// with behavior similar to array
#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 1
#else
#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 0
#endif
#endif
#ifndef CK_TILE_USE_LAUNCH_BOUNDS
#define CK_TILE_USE_LAUNCH_BOUNDS 1
#endif
#ifndef CK_TILE_TIME_KERNEL
#define CK_TILE_TIME_KERNEL 1
#endif
#define CK_TILE_MAX_THREAD_PER_BLOCK 256
#define CK_TILE_MIN_BLOCK_PER_CU 2
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#endif
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#endif
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
#endif
#ifndef CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
#define CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1
#endif
#ifndef CK_TILE_USE_AMD_BUFFER_LOAD
#define CK_TILE_USE_AMD_BUFFER_LOAD 1
#endif
#ifndef CK_TILE_USE_AMD_BUFFER_STORE
#define CK_TILE_USE_AMD_BUFFER_STORE 1
#endif
#ifndef CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
#endif
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#endif
#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
#else
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
#endif
#ifndef CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
#define CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0
#endif
#ifndef CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
#ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#endif
#ifndef CK_TILE_USE_SUBDWORD_TILE_CAST
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <initializer_list>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
// use aggregate initialization for this type
// e.g. array<index_t, 4> buf {0}; => {0, 0, 0, 0}, clean
// array<index_t, 4> buf {3, 2}; => {3, 2, 2, 2} (not {3,2,0,0})
// use make_array_with({...}) to construct an array with behavior compatible with old ck
// TODO: manually added constructor same as old ck
template <typename T_, index_t N_>
struct array
{
using value_type = T_;
static constexpr index_t N = N_;
// TODO: do we need this?
// using bulk_type = uint8_t __attribute__((ext_vector_type(N * sizeof(value_type))));
// union {
value_type data[N];
// bulk_type __content;
//};
CK_TILE_HOST_DEVICE constexpr array() : data{} {}
// TODO: this initializes the remaining data[] entries by repeating the last value,
// a behavior different from std
CK_TILE_HOST_DEVICE constexpr array(std::initializer_list<value_type> ilist)
{
constexpr index_t list_size = std::initializer_list<value_type>{}.size();
static_assert(list_size <= N, "out of bound");
index_t i = 0;
value_type vlast = value_type{};
for(const value_type& val : ilist)
{
data[i] = val;
vlast = val;
++i;
}
for(; i < N; ++i)
{
data[i] = vlast;
}
}
template <typename Y,
typename = std::enable_if_t<std::is_convertible_v<Y, value_type> ||
std::is_constructible_v<Y, value_type>>>
CK_TILE_HOST_DEVICE explicit constexpr array(Y c)
{
for(auto i = 0; i < size(); i++)
data[i] = static_cast<value_type>(c);
}
// template <typename Y>
// CK_TILE_HOST_DEVICE constexpr array(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// }
// CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// return *this;
// }
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
// clang-format off
CK_TILE_HOST_DEVICE constexpr auto& get() { return data; }
CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data; }
CK_TILE_HOST_DEVICE constexpr auto& get(index_t i) { return data[i]; }
CK_TILE_HOST_DEVICE constexpr const auto& get(index_t i) const { return data[i]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get() { return data[I]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data[I]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get(number<I>) { return data[I]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get(number<I>) const { return data[I]; }
CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
CK_TILE_HOST_DEVICE constexpr const value_type& operator[](index_t i) const { return get(i); }
CK_TILE_HOST_DEVICE constexpr value_type& operator[](index_t i) { return get(i); }
CK_TILE_HOST_DEVICE constexpr value_type& operator()(index_t i) { return get(i); } // TODO: compatible
#if 0
template <typename ArrayLike>
CK_TILE_HOST_DEVICE constexpr auto operator=(const ArrayLike& arr)
{
static_assert(ArrayLike::size() == size(), "wrong! size not the same");
for(index_t i = 0; i < size(); ++i)
{
data[i] = arr[i];
}
return *this;
}
#endif
// type punning (strict aliasing) member functions for read/write
// aliasing this array of type "T", "N" elements
// as array of type "Tx", sizeof(T)*N/sizeof(Tx) elements
#define AR_AS_COM_() \
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as()
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as() const
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data); }
// the index below refers to the position *AFTER* the type conversion, not before
template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as(index_t i)
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(i); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as(index_t i) const
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(i); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr auto& get_as(number<I>)
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get_as(number<I>) const
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(number<I>{}); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
{ AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(i) = x; }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
{ AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}) = x; }
#undef AR_AS_COM_
// clang-format on
};
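// Example (illustrative sketch only): the fill-with-last-value aggregate init described above,
// plus the get_as<> reinterpret view. Names and values are assumptions for illustration.
#if 0
CK_TILE_HOST_DEVICE void example_array_usage()
{
    array<index_t, 4> a{3, 2};         // becomes {3, 2, 2, 2}: the last value is repeated
    array<uint16_t, 4> h{uint16_t{7}}; // {7, 7, 7, 7}
    // view the 4 x 16-bit storage as 2 x 32-bit lanes (index is in the converted type)
    uint32_t lo = h.get_as<uint32_t>(0);
    (void)a;
    (void)lo;
}
#endif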
// empty Array
template <typename T>
struct array<T, 0>
{
using value_type = T;
CK_TILE_HOST_DEVICE constexpr array() {}
CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
};
template <typename>
struct vector_traits;
// specialization for array
template <typename T, index_t N>
struct vector_traits<array<T, N>>
{
using scalar_type = T;
static constexpr index_t vector_size = N;
};
namespace details {
template <class>
struct is_ref_wrapper : std::false_type
{
};
template <class T>
struct is_ref_wrapper<std::reference_wrapper<T>> : std::true_type
{
};
template <class T>
using not_ref_wrapper = std::negation<is_ref_wrapper<std::decay_t<T>>>;
template <class D, class...>
struct return_type_helper
{
using type = D;
};
template <class... Ts>
struct return_type_helper<void, Ts...> : std::common_type<Ts...>
{
static_assert(std::conjunction_v<not_ref_wrapper<Ts>...>,
"Ts cannot contain reference_wrappers when D is void");
};
template <class D, class... Ts>
using return_type = array<typename return_type_helper<D, Ts...>::type, sizeof...(Ts)>;
} // namespace details
template <typename D = void, typename... Ts>
CK_TILE_HOST_DEVICE constexpr details::return_type<D, Ts...> make_array(Ts&&... ts)
{
return {std::forward<Ts>(ts)...};
}
// // make empty array
// template <typename T>
// CK_TILE_HOST_DEVICE constexpr auto make_array()
// {
// return array<T, 0>{};
// }
// compatible with old ck's initializer: make an array and fill it with the last element from the
// initializer_list
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr auto make_array_with(std::initializer_list<T> ilist)
{
return array<T, Size>(ilist);
}
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr bool operator==(const array<T, Size>& a, const array<T, Size>& b)
{
bool same = true;
for(index_t i = 0; i < Size; ++i)
{
if(a[i] != b[i])
{
same = false;
break;
}
}
return same;
}
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const array<T, Size>& b)
{
return !(a == b);
}
template <typename T, index_t N, typename X>
CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x)
{
static_assert(N <= X::size(), "");
array<T, N> arr;
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
return arr;
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/map.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array<TData, NSize>& a, const TData& x)
{
array<TData, NSize + 1> r;
static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
r[number<NSize>{}] = x;
return r;
}
template <typename... Ts, typename T>
CK_TILE_HOST_DEVICE constexpr auto container_push_front(const tuple<Ts...>& a, const T& x)
{
return container_concat(make_tuple(x), a);
}
template <typename... Ts, typename T>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const tuple<Ts...>& a, const T& x)
{
return container_concat(a, make_tuple(x));
}
// reorder array
template <typename TData, index_t NSize, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array, sequence<IRs...> /*new2old*/)
{
static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
return make_array<remove_cvref_t<TData>>(old_array[IRs]...);
}
template <typename TData, index_t NSize, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_old2new(const array<TData, NSize>& old_array, sequence<IRs...> old2new)
{
return container_reorder_given_new2old(
old_array, typename sequence_map_inverse<decltype(old2new)>::type{});
}
// reorder array
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array,
const map<index_t, index_t>& new2old)
{
array<TData, NSize> new_array;
for(const auto& [new_pos, old_pos] : new2old)
{
new_array(new_pos) = old_array[old_pos];
}
return new_array;
}
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_old2new(const array<TData, NSize>& old_array,
const map<index_t, index_t>& old2new)
{
array<TData, NSize> new_array;
for(const auto& [old_pos, new_pos] : old2new)
{
new_array(new_pos) = old_array[old_pos];
}
return new_array;
}
// reorder tuple
template <typename... Ts, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const tuple<Ts...>& old_tuple,
sequence<IRs...> /*new2old*/)
{
static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
return make_tuple(old_tuple[number<IRs>{}]...);
}
template <typename... Ts, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(const tuple<Ts...>& old_tuple,
sequence<IRs...> old2new)
{
return container_reorder_given_new2old(
old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
}
// reorder sequence
template <index_t... Is, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(sequence<Is...> /* old_seq */,
sequence<IRs...> /*new2old*/)
{
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
return sequence<sequence<Is...>::at(number<IRs>{})...>{};
}
template <index_t... Is, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(sequence<Is...> old_seq,
sequence<IRs...> /* old2new */)
{
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
constexpr auto new2old = typename sequence_map_inverse<sequence<IRs...>>::type{};
return container_reorder_given_new2old(old_seq, new2old);
}
#if 0
// rocm-4.1 compiler would crash for recursive lambda
template <typename Container,
typename Reduce,
typename Init,
index_t IBegin = 0,
index_t IEnd = Container::size(),
index_t IStep = 1>
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
Reduce reduce,
Init init,
number<IBegin> = number<0>{},
number<IEnd> = number<Container::size()>{},
number<IStep> = number<1>{})
{
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
// f is recursive function, fs is a dummy of f
// i is index, y_old is current scan, r_old is current reduction
auto f = [&](auto fs, auto i, auto r_old) {
auto r_new = reduce(x[i], r_old);
if constexpr(i.value < IEnd - IStep)
{
// recursively call f/fs
return fs(fs, i + number<IStep>{}, r_new);
}
else
{
return r_new;
}
};
// start recursion
return f(f, number<IBegin>{}, init);
}
#else
// i is the index, r_old is the current reduction
template <typename Container,
typename Reduce,
typename ROld,
index_t I,
index_t IEnd,
index_t IStep>
CK_TILE_HOST_DEVICE constexpr auto container_reduce_impl(
const Container& x, Reduce reduce, ROld r_old, number<I> i, number<IEnd>, number<IStep>)
{
auto r_new = reduce(x[i], r_old);
if constexpr(i.value < IEnd - IStep)
{
return container_reduce_impl(
x, reduce, r_new, i + number<IStep>{}, number<IEnd>{}, number<IStep>{});
}
else
{
return r_new;
}
}
// rocm-4.1 compiler would crash for recursive lambda
// container reduce with initial value
template <typename Container,
typename Reduce,
typename Init,
index_t IBegin = 0,
index_t IEnd = Container::size(),
index_t IStep = 1>
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
Reduce reduce,
Init init,
number<IBegin> = number<0>{},
number<IEnd> = number<Container::size()>{},
number<IStep> = number<1>{})
{
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
if constexpr(IEnd > IBegin)
{
return container_reduce_impl(
x, reduce, init, number<IBegin>{}, number<IEnd>{}, number<IStep>{});
}
else
{
return init;
}
}
#endif
template <typename TData, index_t NSize, typename Reduce>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_inclusive_scan(const array<TData, NSize>& x, Reduce f, TData init)
{
array<TData, NSize> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
r = f(r, x[i]);
y(i) = r;
});
r = f(r, x[number<0>{}]);
y(number<0>{}) = r;
return y;
}
template <typename TData, index_t NSize, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const array<TData, NSize>& x, Reduce f, Init init)
{
#if 0
array<TData, NSize> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
y(i) = r;
r = f(r, x[i]);
});
y(number<0>{}) = r;
return y;
#else
array<TData, NSize> y;
TData r = init;
for(index_t i = NSize - 1; i > 0; --i)
{
y(i) = r;
r = f(r, x[i]);
}
y(0) = r;
return y;
#endif
}
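// Example (illustrative sketch only): a reverse exclusive scan with multiplication turns a
// length array into packed (row-major) strides, e.g. {2, 3, 4} -> {12, 4, 1}.
#if 0
CK_TILE_HOST_DEVICE auto example_packed_strides()
{
    const array<index_t, 3> lengths{2, 3, 4};
    return container_reverse_exclusive_scan(
        lengths, [](index_t a, index_t b) { return a * b; }, index_t{1}); // {12, 4, 1}
}
#endif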
template <index_t... Is, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const sequence<Is...>& seq, Reduce f, number<Init>)
{
return reverse_exclusive_scan_sequence(seq, f, number<Init>{});
}
#if 0
// rocm4.1 compiler would crash with recursive lambda
template <typename... Xs, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
{
constexpr index_t NSize = sizeof...(Xs);
// f is recursive function, fs is a dummy of f
// i is index, y_old is current scan, r_old is current reduction
auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
auto r_new = reduce(x[i], r_old);
auto y_new = container_push_front(y_old, r_new);
if constexpr(i.value > 1)
{
// recursively call f/fs
return fs(fs, i - number<1>{}, y_new, r_new);
}
else
{
return y_new;
}
};
// start recursion
return f(f, number<NSize - 1>{}, make_tuple(init), init);
}
#else
// i is index, y_old is current scan, r_old is current reduction
template <typename... Xs, typename Reduce, index_t I, typename YOld, typename ROld>
CK_TILE_HOST_DEVICE constexpr auto container_reverse_exclusive_scan_impl(
const tuple<Xs...>& x, Reduce reduce, number<I> i, YOld y_old, ROld r_old)
{
auto r_new = reduce(x[i], r_old);
auto y_new = container_push_front(y_old, r_new);
if constexpr(i.value > 1)
{
// recursively call f/fs
return container_reverse_exclusive_scan_impl(x, reduce, i - number<1>{}, y_new, r_new);
}
else
{
return y_new;
}
}
template <typename... Xs, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
{
constexpr index_t NSize = sizeof...(Xs);
return container_reverse_exclusive_scan_impl(
x, reduce, number<NSize - 1>{}, make_tuple(init), init);
}
#endif
// TODO: update this to be like container_reverse_exclusive_scan, to deal with a tuple of number<>
template <typename... Xs, typename Reduce, typename TData>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_inclusive_scan(const tuple<Xs...>& x, Reduce f, TData init)
{
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
r = f(r, x[i]);
y(i) = r;
});
r = f(r, x[number<0>{}]);
y(number<0>{}) = r;
return y;
}
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X& x, const Ys&... ys)
{
return container_concat(x, container_concat(ys...));
}
template <typename T, index_t NX, index_t NY>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const array<T, NX>& ax, const array<T, NY>& ay)
{
return unpack2(
[&](auto&&... zs) { return make_array<T>(std::forward<decltype(zs)>(zs)...); }, ax, ay);
}
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const tuple<X...>& tx, const tuple<Y...>& ty)
{
return unpack2(
[&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
}
template <typename Container>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const Container& x)
{
return x;
}
template <typename T, index_t N, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array<T, N>& arr, sequence<Is...>)
{
static_assert(N >= sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
return make_array<T>(arr[Is]...);
}
else
{
return array<T, 0>{};
}
}
template <typename... Ts, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const tuple<Ts...>& tup, sequence<Is...>)
{
static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
return make_tuple(tup[number<Is>{}]...);
}
else
{
return tuple<>{};
}
}
template <typename T, index_t N, index_t... Is>
CK_TILE_HOST_DEVICE constexpr void
set_container_subset(array<T, N>& y, sequence<Is...> picks, const array<T, sizeof...(Is)>& x)
{
static_assert(N >= sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
for(index_t i = 0; i < picks.size(); ++i)
{
y(picks[i]) = x[i];
}
}
}
template <typename Y, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr void set_container_subset(Y& y, sequence<Is...> picks, const X& x)
{
static_assert(Y::size() >= sizeof...(Is) && X::size() == sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}
}
// return the index of the first occurrence in the sequence.
// return seq.size() if not found
template <index_t... Is>
constexpr index_t container_find(sequence<Is...> seq, index_t value)
{
for(auto i = 0; i < seq.size(); i++)
{
if(seq[i] == value)
return i;
}
return seq.size();
}
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
{
using Seq = sequence<Is...>;
return generate_tuple(
[&](auto i) {
constexpr index_t tmp = Seq::at(i);
return number<tmp>{};
},
number<Seq::size()>{});
}
#if 0
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
[a_of_b_impl, a_size, bs_sizes] { \
return ck_tile::generate_tuple( \
[=](auto i) { \
constexpr auto b_impl = a_of_b_impl[i]; \
constexpr index_t b_size = bs_sizes[i]; \
constexpr auto b = TO_SEQUENCE(b_impl, b_size); \
return b; \
}, \
ck_tile::number<a_size>{}); \
}()
#else
// capturing a constexpr index_t would trigger "-Wunused-lambda-capture", so a_size is left out
// TODO: this is ugly
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
[a_of_b_impl, bs_sizes] { \
return ck_tile::generate_tuple( \
[=](auto i) { \
constexpr auto b_impl = a_of_b_impl[i]; \
constexpr index_t b_size = bs_sizes[i]; \
constexpr auto b = TO_SEQUENCE(b_impl, b_size); \
return b; \
}, \
ck_tile::number<a_size>{}); \
}()
#endif
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
namespace ck_tile {
// naive map
template <typename key, typename data, index_t max_size = 128>
struct map
{
using pair_type = tuple<key, data>;
using impl_type = array<pair_type, max_size>;
impl_type impl_;
index_t size_;
struct iterator
{
impl_type& impl_;
index_t pos_;
CK_TILE_HOST_DEVICE constexpr iterator(impl_type& impl, index_t pos)
: impl_{impl}, pos_{pos}
{
}
CK_TILE_HOST_DEVICE constexpr iterator& operator++()
{
pos_++;
return *this;
}
CK_TILE_HOST_DEVICE constexpr bool operator!=(const iterator& other) const
{
return other.pos_ != pos_;
}
CK_TILE_HOST_DEVICE constexpr pair_type& operator*() { return impl_.at(pos_); }
};
struct const_iterator
{
const impl_type& impl_;
index_t pos_;
CK_TILE_HOST_DEVICE constexpr const_iterator(const impl_type& impl, index_t pos)
: impl_{impl}, pos_{pos}
{
}
CK_TILE_HOST_DEVICE constexpr const_iterator& operator++()
{
pos_++;
return *this;
}
CK_TILE_HOST_DEVICE constexpr bool operator!=(const const_iterator& other) const
{
return other.pos_ != pos_;
}
CK_TILE_HOST_DEVICE constexpr const pair_type& operator*() const { return impl_.at(pos_); }
};
CK_TILE_HOST_DEVICE constexpr map() : impl_{}, size_{0} {}
CK_TILE_HOST_DEVICE constexpr index_t size() const { return size_; }
CK_TILE_HOST_DEVICE void clear() { size_ = 0; }
CK_TILE_HOST_DEVICE constexpr index_t find_position(const key& k) const
{
for(index_t i = 0; i < size(); i++)
{
if(impl_[i].template at<0>() == k)
{
return i;
}
}
return size_;
}
CK_TILE_HOST_DEVICE constexpr const_iterator find(const key& k) const
{
return const_iterator{impl_, find_position(k)};
}
CK_TILE_HOST_DEVICE constexpr iterator find(const key& k)
{
return iterator{impl_, find_position(k)};
}
CK_TILE_HOST_DEVICE constexpr const data& operator[](const key& k) const
{
const auto it = find(k);
// FIXME
// assert(it.pos_ < size());
return impl_[it.pos_].template at<1>();
}
CK_TILE_HOST_DEVICE constexpr data& operator()(const key& k)
{
auto it = find(k);
// if entry not found
if(it.pos_ == size())
{
impl_(it.pos_).template at<0>() = k;
size_++;
}
// FIXME
// assert(size_ <= max_size);
return impl_(it.pos_).template at<1>();
}
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr const_iterator begin() const { return const_iterator{impl_, 0}; }
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr const_iterator end() const
{
return const_iterator{impl_, size_};
}
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr iterator begin() { return iterator{impl_, 0}; }
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }
CK_TILE_HOST_DEVICE void print() const
{
printf("map{size_: %d, ", size_);
//
printf("impl_: [");
//
for(const auto& [k, d] : *this)
{
printf("{key: ");
print(k);
printf(", data: ");
print(d);
printf("}, ");
}
//
printf("]");
//
printf("}");
}
};
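// Example (illustrative sketch only): the naive map behaves like a tiny fixed-capacity
// unordered_map; operator() inserts on first access, operator[] only reads.
#if 0
CK_TILE_HOST_DEVICE index_t example_map_usage()
{
    map<index_t, index_t> dim_to_len;
    dim_to_len(0) = 16;   // insert key 0
    dim_to_len(1) = 32;   // insert key 1
    dim_to_len(0) = 64;   // overwrite key 0
    return dim_to_len[1]; // read-only lookup, returns 32
}
#endif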
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <cstddef>
namespace ck_tile {
// TODO: this structure is not intended to be used by the user
template <index_t MaxSize>
struct meta_data_buffer
{
CK_TILE_HOST_DEVICE constexpr meta_data_buffer() : buffer_{}, size_{0} {}
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr meta_data_buffer(const X& x, const Xs&... xs)
: buffer_{}, size_{0}
{
push(x, xs...);
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr void push(const T& data)
{
if constexpr(!std::is_empty_v<T>)
{
constexpr index_t size = sizeof(T);
auto tmp = bit_cast<array<std::byte, size>>(data);
for(int i = 0; i < size; i++)
{
buffer_(size_) = tmp[i];
size_++;
}
}
}
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr void push(const X& x, const Xs&... xs)
{
push(x);
push(xs...);
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T pop(index_t& pos) const
{
T data;
if constexpr(!std::is_empty_v<T>)
{
constexpr index_t size = sizeof(T);
array<std::byte, size> tmp;
for(int i = 0; i < size; i++)
{
tmp(i) = buffer_[pos];
pos++;
}
data = bit_cast<T>(tmp);
}
return data;
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T get(index_t pos) const
{
constexpr index_t size = sizeof(T);
array<std::byte, size> tmp;
for(int i = 0; i < size; i++)
{
tmp(i) = buffer_[pos];
pos++;
}
auto data = bit_cast<T>(tmp);
return data;
}
//
array<std::byte, MaxSize> buffer_;
index_t size_ = 0;
};
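// Example (illustrative sketch only): packing two values into the byte buffer and popping
// them back in the same order. Names and values are assumptions for illustration.
#if 0
CK_TILE_HOST_DEVICE void example_meta_data_buffer()
{
    meta_data_buffer<64> meta(index_t{128}, 3.5f); // pushes an index_t, then a float
    index_t pos      = 0;
    const auto m     = meta.pop<index_t>(pos); // 128; pos advances by sizeof(index_t)
    const auto alpha = meta.pop<float>(pos);   // 3.5f
    (void)m;
    (void)alpha;
}
#endif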
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
// Don't use this directly. This is for old CK's internal usage;
// in the future always use array instead
template <index_t N>
using multi_index = array<index_t, N>;
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_multi_index(Xs&&... xs)
{
return make_array<index_t>(index_t{xs}...);
}
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto make_zero_multi_index()
{
return unpack([](auto... xs) { return make_multi_index(xs...); },
typename uniform_sequence_gen<NSize, 0>::type{});
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr auto to_multi_index(const T& x)
{
return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
}
template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator+=(multi_index<NSize>& y, const X& x)
{
static_assert(X::size() == NSize, "wrong! size not the same");
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
return y;
}
template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator-=(multi_index<NSize>& y, const X& x)
{
static_assert(X::size() == NSize, "wrong! size not the same");
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
return y;
}
template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator+(const multi_index<NSize>& a, const T& b)
{
using type = multi_index<NSize>;
static_assert(T::size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] + b[i]; });
return r;
}
template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator-(const multi_index<NSize>& a, const T& b)
{
using type = multi_index<NSize>;
static_assert(T::size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] - b[i]; });
return r;
}
template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& a, const T& b)
{
using type = multi_index<NSize>;
static_assert(T::size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] * b[i]; });
return r;
}
// multi_index = index_t * multi_index
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(index_t a, const multi_index<NSize>& x)
{
multi_index<NSize> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
return r;
}
// multi_index = multi_index * index_t
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& x, index_t a)
{
return a * x;
}
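// Example (illustrative sketch only): element-wise multi_index arithmetic, e.g. stepping a
// 2D coordinate by twice a tile size.
#if 0
CK_TILE_HOST_DEVICE auto example_multi_index_step()
{
    const auto origin = make_multi_index(4, 8);
    const auto step   = make_multi_index(2, 3);
    return origin + 2 * step; // {8, 14}
}
#endif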
} // namespace ck_tile