Commit dacc5119 authored by Po-Yen, Chen
Browse files

Simplify permute bundle example

parent aef15d2e
...@@ -52,36 +52,6 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -52,36 +52,6 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
namespace detail { namespace detail {
template <typename Bundle, std::size_t Divisor>
struct get_bundled;
template <typename Bundle>
struct get_bundled<Bundle, 1>
{
using type = Bundle;
};
template <>
struct get_bundled<F64, 2>
{
using type = F32;
};
template <>
struct get_bundled<F64, 4>
{
using type = F16;
};
template <>
struct get_bundled<F32, 2>
{
using type = F16;
};
template <typename Bundle, std::size_t Divisor>
using get_bundled_t = typename get_bundled<Bundle, Divisor>::type;
template <typename Array, std::size_t Difference> template <typename Array, std::size_t Difference>
struct enlarge_array_size; struct enlarge_array_size;
......
...@@ -3,21 +3,20 @@ ...@@ -3,21 +3,20 @@
#include "common.hpp" #include "common.hpp"
using ADataType = F64; using DataType = F16;
using BDataType = F64; using BundleType = F64;
static_assert(sizeof(BundleType) % sizeof(DataType) == 0);
// clang-format off // clang-format off
using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
// ######| InData| OutData| Elementwise| NumDim| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst| // ######| InData| OutData| Elementwise| NumDim| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
// ######| Type| Type| Operation| | Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector| // ######| Type| Type| Operation| | Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | |
< ADataType, BDataType, PassThrough, 3, 256, 1, 32, 32, 5, S<1, 32, 8>, S<0, 1, 2>, 2, 1, 4, 1>; < BundleType, BundleType, PassThrough, 3, 256, 1, 32, 32, 5, S<1, 32, 8>, S<0, 1, 2>, 2, 1, 4, 1>;
// clang-format on // clang-format on
#define NUM_ELEMS_IN_BUNDLE 4
static_assert(std::is_same_v<detail::get_bundled_t<F64, NUM_ELEMS_IN_BUNDLE>, F16>);
#include "run_permute_bundle_example.inc" #include "run_permute_bundle_example.inc"
int main() { return !run_permute_bundle_example({1, 80, 16000}, {0, 2, 1}); } int main() { return !run_permute_bundle_example({1, 80, 16000}, {0, 2, 1}); }
...@@ -5,8 +5,7 @@ ...@@ -5,8 +5,7 @@
bool run_permute_bundle(const Problem& problem) bool run_permute_bundle(const Problem& problem)
{ {
static_assert(std::is_same_v<ADataType, BDataType> && constexpr std::size_t NumElemsInBundle = sizeof(BundleType) / sizeof(DataType);
(sizeof(ADataType) % NUM_ELEMS_IN_BUNDLE == 0));
using std::begin, std::end; using std::begin, std::end;
...@@ -14,19 +13,19 @@ bool run_permute_bundle(const Problem& problem) ...@@ -14,19 +13,19 @@ bool run_permute_bundle(const Problem& problem)
ck::remove_cvref_t<decltype(shape)> transposed_shape; ck::remove_cvref_t<decltype(shape)> transposed_shape;
transpose_shape(problem.shape, problem.axes, begin(transposed_shape)); transpose_shape(problem.shape, problem.axes, begin(transposed_shape));
Tensor<ADataType> a(shape); Tensor<BundleType> a(shape);
Tensor<BDataType> b(transposed_shape); Tensor<BundleType> b(transposed_shape);
// initialize tensor by assigning DataType values
using std::data, std::size; using std::data, std::size;
{ {
auto* const elems = auto* const elems = reinterpret_cast<DataType*>(data(a.mData));
reinterpret_cast<detail::get_bundled_t<ADataType, NUM_ELEMS_IN_BUNDLE>*>(data(a.mData)); ck::utils::FillUniformDistribution<DataType>{-1.f, 1.f}(
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}( elems, elems + (size(a.mData) * NumElemsInBundle));
elems, elems + (size(a.mData) * NUM_ELEMS_IN_BUNDLE));
} }
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(sizeof(BundleType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(sizeof(BundleType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(data(a.mData)); a_device_buf.ToDevice(data(a.mData));
...@@ -61,10 +60,8 @@ bool run_permute_bundle(const Problem& problem) ...@@ -61,10 +60,8 @@ bool run_permute_bundle(const Problem& problem)
b_device_buf.FromDevice(data(b.mData)); b_device_buf.FromDevice(data(b.mData));
// extend tensor shape from [N, H, W] to [N, H, W, NUM_ELEMS_IN_BUNDLE] // extend tensor shape from [N, H, W] to [N, H, W, NumElemsInBundle]
using DataType = detail::get_bundled_t<ADataType, NUM_ELEMS_IN_BUNDLE>; const auto extended_shape = extend_shape(shape, NumElemsInBundle);
const auto extended_shape = extend_shape(shape, NUM_ELEMS_IN_BUNDLE);
const auto extended_axes = extend_axes(problem.axes); const auto extended_axes = extend_axes(problem.axes);
ck::remove_cvref_t<decltype(extended_shape)> transposed_extended_shape; ck::remove_cvref_t<decltype(extended_shape)> transposed_extended_shape;
...@@ -72,7 +69,7 @@ bool run_permute_bundle(const Problem& problem) ...@@ -72,7 +69,7 @@ bool run_permute_bundle(const Problem& problem)
Tensor<DataType> extended_a(extended_shape); Tensor<DataType> extended_a(extended_shape);
std::memcpy( std::memcpy(
data(extended_a.mData), data(a.mData), sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); data(extended_a.mData), data(a.mData), sizeof(BundleType) * a.mDesc.GetElementSpaceSize());
Tensor<DataType> extended_host_b(transposed_extended_shape); Tensor<DataType> extended_host_b(transposed_extended_shape);
if(!host_permute(extended_a, extended_axes, PassThrough{}, extended_host_b)) if(!host_permute(extended_a, extended_axes, PassThrough{}, extended_host_b))
...@@ -82,7 +79,7 @@ bool run_permute_bundle(const Problem& problem) ...@@ -82,7 +79,7 @@ bool run_permute_bundle(const Problem& problem)
return ck::utils::check_err( return ck::utils::check_err(
ck::span<const DataType>{reinterpret_cast<DataType*>(data(b.mData)), ck::span<const DataType>{reinterpret_cast<DataType*>(data(b.mData)),
b.mDesc.GetElementSpaceSize() * NUM_ELEMS_IN_BUNDLE}, b.mDesc.GetElementSpaceSize() * NumElemsInBundle},
ck::span<const DataType>{extended_host_b.mData}, ck::span<const DataType>{extended_host_b.mData},
"Error: incorrect results in output tensor", "Error: incorrect results in output tensor",
1e-6, 1e-6,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment