Commit 55a89c74 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 0dacd895 dcedf363
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/dynamic_buffer.hpp"
#include "ck/utility/amd_address_space.hpp"
namespace ck {
namespace wrapper {
/**
* \brief Memory type, allowed members:
* - Generic,
* - Global,
* - LDS,
* - SGPR,
* - VGPR,
*/
using MemoryTypeEnum = AddressSpaceEnum;
// Disable from doxygen docs generation
/// @cond
// forward declarations
template <typename Shape, typename Strides>
struct Layout;
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors, // params for Register memory
index_t ScalarPerVector // param for Register memory
>
struct Tensor;
template <typename FromType, typename ToType>
struct Slice
{
__host__ __device__ constexpr Slice() : from_(), to_() {}
__host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {}
template <typename T>
__host__ __device__ constexpr auto range(const T& dim) const
{
if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
is_same_v<T, index_t>)
{
assert(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_) && "Invalid range");
if(to_ < 0)
{
return dim - from_ + to_ + 1;
}
else
{
// workaround if one end of the interval is index_t and the second one is Number
return static_cast<index_t>(to_) - static_cast<index_t>(from_);
}
}
else
{
static_assert(dim >= to_ && from_ >= Number<0>{} && (to_ < 0 || to_ > from_),
"Invalid range");
if constexpr(to_ < 0)
{
return dim - from_ + to_ + Number<1>{};
}
else
{
return to_ - from_;
}
}
}
__host__ __device__ static constexpr bool IsSlice() { return true; }
const FromType from_;
const ToType to_;
};
template <typename T>
using is_slice = decltype(std::declval<T&>().IsSlice());
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
/// @endcond
/**
* \brief Make tensor function.
*
* \tparam MemoryType Type of memory.
* \param pointer Pointer to the memory.
* \param layout Tensor layout.
* \return Constructed tensor.
*/
template <MemoryTypeEnum MemoryType, typename ElementType, typename Shape, typename Strides>
constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& layout)
{
return Tensor<MemoryType, ElementType, Shape, Strides, 0 /*NumVectors*/, 0 /*ScalarPerVector*/>(
pointer, layout);
}
/**
* \brief Make SGPR or VGPR tensor function.
*
* \tparam MemoryType Type of memory.
* \tparam NumVectors Number of vectors.
* \tparam ScalarPerVector Scalars per vector.
* \tparam ElementType Memory data type.
* \param layout Tensor layout.
* \return Constructed tensor.
*/
template <MemoryTypeEnum MemoryType,
index_t NumVectors,
index_t ScalarPerVector,
typename ElementType,
typename Shape,
typename Strides>
constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
{
static_assert(!IsNestedTuple(Shape{}), "Register tensor with nested layout is not supported");
return Tensor<MemoryType, ElementType, Shape, Strides, NumVectors, ScalarPerVector>(layout);
}
/**
* \brief Get Tensor Layout.
*
* \param tensor Tensor to get layout of.
* \return Requsted layout.
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
layout(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
{
return tensor.GetLayout();
}
/**
* \brief Product of tensor shape dims.
*
* \tparam Idxs Indexes to access specific shape dim (optional).
* \param tensor Tensor to get Shape of.
* \return Requsted size.
*/
template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr index_t
size(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
{
return size<Idxs...>(tensor.GetLayout());
}
/**
* \brief Rank of Shape tuple.
*
* \tparam Idxs Indexes to access specific shape dim (optional).
* \param tensor Tensor to get rank of.
* \return Requsted rank.
*/
template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr index_t
rank(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
{
return rank<Idxs...>(tensor.GetLayout());
}
/**
* \brief Depth of Shape tuple.
*
* \tparam Idxs Indexes to access specific shape dim (optional).
* \param tensor Tensor to get depth of.
* \return Requsted depth.
*/
template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr index_t
depth(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
{
return depth<Idxs...>(tensor.GetLayout());
}
/**
* \brief Get Tensor strides.
*
* \param tensor Tensor to get strides from.
* \return Requsted strides.
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
{
return stride(tensor.GetLayout());
}
/**
* \brief Get Tensor shape.
*
* \param tensor Tensor to get shape from.
* \return Requsted shape.
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
shape(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
{
return shape(tensor.GetLayout());
}
/**
* \brief Get dim slice.
*
* \param from Beginning of the interval.
* \param to End of the interval. (could be also negative to index from the end)
* \return Requested slice. Could be used to create sliced tensor from other tensor.
*/
template <typename FromType, typename ToType>
constexpr auto slice(const FromType from, const ToType to)
{
return Slice<FromType, ToType>(from, to);
}
/**
* \brief Get dim slice. (Assumed that from is equal to 1)
*
* \param to End of the interval. (could be also negative to index from the end)
* \return Requested slice. Could be used to create sliced tensor from other tensor.
*/
template <typename ToType>
constexpr auto slice(const ToType to)
{
if constexpr(is_same_v<ToType, index_t>)
{
return Slice<index_t, ToType>(0, to);
}
else
{
return Slice<Number<0>, ToType>(Number<0>{}, to);
}
}
/**
* \brief Get whole dim slice (from = 0, to = -1).
*
* \return Requested slice. Could be used to create sliced tensor from other tensor.
*/
constexpr auto slice() { return Slice<Number<0>, Number<-1>>(Number<0>{}, Number<-1>{}); }
} // namespace wrapper
} // namespace ck
......@@ -86,9 +86,9 @@ using NHWGK = ck::tensor_layout::convolution::NHWGK;
using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
//
using GK = ck::tensor_layout::convolution::G_K;
using GK_Tuple = ck::Tuple<GK>;
using GK_GK_Tuple = ck::Tuple<GK, GK>;
using G_K = ck::tensor_layout::convolution::G_K;
using GK_Tuple = ck::Tuple<G_K>;
using GK_GK_Tuple = ck::Tuple<G_K, G_K>;
// pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
......
......@@ -61,7 +61,11 @@ using device_contraction_kk_instance = std::tuple<
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, ComputeDataType>
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, ComputeDataType>,
// Small scalar per vector
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType>
// clang-format on
>;
......@@ -96,7 +100,11 @@ using device_contraction_kn_instance = std::tuple<
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
// Small scalar per vector
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType>
// clang-format on
>;
......@@ -131,7 +139,11 @@ using device_contraction_mk_instance = std::tuple<
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
// Small scalar per vector
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType>
// clang-format on
>;
......@@ -166,7 +178,11 @@ using device_contraction_mn_instance = std::tuple<
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
// Small scalar per vector
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType>
// clang-format on
>;
......
......@@ -25,10 +25,6 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
template <ck::tensor_operation::device::GemmSpecialization GemmSpec>
using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple<
......@@ -37,7 +33,7 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple<
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
// pipeline v1, 1 wave
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
......@@ -75,7 +71,8 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple<
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
#if 0
//CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
// pipeline v2, 1 wave
,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
......@@ -98,17 +95,6 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple<
// clang-format on
>;
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances<MNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -345,7 +345,11 @@ void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
......@@ -575,7 +579,8 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(op_ptrs);
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances(op_ptrs);
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
......
......@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
BF16,
BF16,
......@@ -43,7 +43,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
F16,
F16,
......@@ -59,7 +59,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
F32,
F32,
......@@ -75,7 +75,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
int8_t,
int8_t,
......@@ -130,7 +130,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK> &&
DLayouts::Size() == 2 && is_same_v<tuple_element_t<0, DLayouts>, NDHWGK> &&
is_same_v<tuple_element_t<1, DLayouts>, G_K>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
......
......@@ -101,7 +101,8 @@ list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp)
......
......@@ -16,6 +16,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
......
......@@ -16,6 +16,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp"
#ifdef CK_ENABLE_FP8
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances<GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp"
#ifdef CK_ENABLE_FP8
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances<MNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
BF16,
BF16,
......@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
......@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
......@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
ConvFwd1x1S1P0>{});
}
......
......@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
F16,
F16,
......@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
......@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
......@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
ConvFwd1x1S1P0>{});
}
......
......@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
F32,
F32,
......@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
......@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
......@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
ConvFwd1x1S1P0>{});
}
......
......@@ -12,7 +12,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
int8_t,
int8_t,
......@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
......@@ -35,7 +35,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
......@@ -43,7 +43,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>,
ck::Tuple<NDHWGK, G_K>,
NDHWGK,
ConvFwd1x1S1P0>{});
}
......
......@@ -22,13 +22,13 @@ using S = ck::Sequence<Is...>;
using NHWGC = ck::tensor_layout::convolution::NHWGC;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using NHWGK = ck::tensor_layout::convolution::NHWGK;
using GK = ck::tensor_layout::convolution::G_K;
using G_K = ck::tensor_layout::convolution::G_K;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
using TanH = ck::tensor_operation::element_wise::TanH;
using GK_Tuple = ck::Tuple<GK>;
using GK_GK_Tuple = ck::Tuple<GK, GK>;
using GK_Tuple = ck::Tuple<G_K>;
using GK_GK_Tuple = ck::Tuple<G_K, G_K>;
using I32_Tuple = ck::Tuple<int32_t>;
using F32_Tuple = ck::Tuple<float>;
using I32_F32_Tuple = ck::Tuple<int32_t, float>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_transpose_impl.hpp"
#include "profiler_operation_registry.hpp"
enum struct MatrixLayout
{
NCDHW, // 0
NCHWD, // 1
};
enum struct DataType
{
F32_F32_F32_F32_F32, // 0
F16_F16_F16_F16_F16, // 1
};
#define OP_NAME "transpose"
#define OP_DESC "Transpose"
int profile_transpose(int argc, char* argv[])
{
if(argc != 15)
{
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16)\n");
// printf("arg3: matrix layout (NCDHW -> NDCHW);\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 13: N, C, D, H, W\n");
exit(1);
}
const auto data_type = static_cast<DataType>(std::stoi(argv[2]));
// const auto layout = static_cast<MatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[3]);
const int init_method = std::stoi(argv[4]);
const bool do_log = std::stoi(argv[5]);
const bool time_kernel = std::stoi(argv[6]);
std::vector<index_t> lengths = std::stoi(argv[7]);
/**const int N = std::stoi(argv[7]);
const int C = std::stoi(argv[8]);
const int D = std::stoi(argv[9]);
const int H = std::stoi(argv[10]);
const int W = std::stoi(argv[11]);**/
using F32 = float;
using F16 = ck::half_t;
auto profile = [&](auto a_type, auto b_type) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
bool pass = ck::profiler::profile_transpose_impl<ADataType, BDataType>(
do_verification, init_method, do_log, time_kernel, lengths);
return pass ? 0 : 1;
};
if(data_type == GemmDataType::F32_F32_F32_F32_F32)
{
return profile(F32{}, F32{});
}
else if(data_type == GemmDataType::F16_F16_F16_F16_F16)
{
return profile(F16{}, F16{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_transpose);
......@@ -149,6 +149,7 @@ add_subdirectory(batched_gemm_multi_d)
add_subdirectory(grouped_convnd_bwd_data)
add_subdirectory(conv_tensor_rearrange)
add_subdirectory(transpose)
add_subdirectory(wrapper)
if(GPU_TARGETS MATCHES "gfx11")
add_subdirectory(wmma_op)
endif()
add_gtest_executable(test_layout test_layout.cpp)
target_link_libraries(test_layout PRIVATE utility)
add_gtest_executable(test_tensor test_tensor.cpp)
target_link_libraries(test_tensor PRIVATE utility)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <vector>
#include <gtest/gtest.h>
#include "ck/utility/common_header.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
class TestWrapperLayout : public ::testing::Test
{
protected:
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
template <typename Desc,
typename Desc1d,
typename LayoutRuntime,
typename LayoutCompiletime,
typename Idxs>
void Run(Desc& desc,
Desc1d& desc_1d,
LayoutRuntime& layout_runtime,
LayoutCompiletime& layout_compiletime,
const std::vector<Idxs>& idxs)
{
// 1d check
EXPECT_EQ(desc_1d.GetLength(I0), ck::wrapper::size(layout_runtime));
// Check layout compiletime and runtime result consistency
EXPECT_EQ(ck::wrapper::size(layout_runtime), ck::wrapper::size(layout_compiletime));
for(ck::index_t i = 0; i < desc_1d.GetLength(I0); i++)
{
const ck::index_t layout_runtime_offset_1d = layout_runtime(ck::make_tuple(i));
const ck::index_t layout_compiletime_offset_1d = layout_compiletime(ck::make_tuple(i));
const ck::index_t desc_offset_1d = desc_1d.CalculateOffset(ck::make_tuple(i));
EXPECT_EQ(layout_runtime_offset_1d, desc_offset_1d);
EXPECT_EQ(layout_compiletime_offset_1d, layout_runtime_offset_1d);
}
// size(layout)-d check, don't check if access is hierarchical
if constexpr(!IsNestedTuple(Idxs{}))
{
ck::static_for<0, Idxs::Size(), 1>{}([&](auto d) {
EXPECT_EQ(desc.GetLength(ck::Number<d>{}), ck::wrapper::size<d>(layout_runtime));
EXPECT_EQ(ck::wrapper::size<d>(layout_runtime),
ck::wrapper::size<d>(layout_compiletime));
});
}
for(const auto idx : idxs)
{
const ck::index_t layout_runtime_offset = layout_runtime(idx);
const ck::index_t layout_compiletime_offset = layout_compiletime(idx);
const ck::index_t desc_offset =
desc.CalculateOffset(UnrollNestedTuple(idx)); // Unroll if nested
EXPECT_EQ(layout_runtime_offset, desc_offset);
EXPECT_EQ(layout_runtime_offset, layout_compiletime_offset);
}
}
};
TEST_F(TestWrapperLayout, 2d)
{
// dims:(4, 3) strides:(1, 4)
constexpr ck::index_t d1 = 4;
constexpr ck::index_t d0 = 3;
constexpr ck::index_t s1 = 1;
constexpr ck::index_t s0 = 4;
const auto desc =
ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{}),
ck::make_tuple(ck::Number<s1>{}, ck::Number<s0>{}));
// Reverse due to column major
const auto desc_1d = transform_tensor_descriptor(
desc,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1))),
ck::make_tuple(ck::Sequence<1, 0>{}),
ck::make_tuple(ck::Sequence<0>{}));
const auto layout_runtime = ck::wrapper::make_layout(ck::make_tuple(d1, d0));
const auto layout_compiletime =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{}));
std::vector<ck::Tuple<ck::index_t, ck::index_t>> idxs;
for(ck::index_t h = 0; h < d1; h++)
{
for(ck::index_t w = 0; w < d0; w++)
{
idxs.emplace_back(h, w);
}
}
this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs);
}
TEST_F(TestWrapperLayout, 3d_nested)
{
// dims:((2, 3), 4, 3) strides:((2, 4), 12, 48)
constexpr ck::index_t d3 = 2;
constexpr ck::index_t d2 = 3;
constexpr ck::index_t d1 = 4;
constexpr ck::index_t d0 = 3;
constexpr ck::index_t s3 = 2;
constexpr ck::index_t s2 = 4;
constexpr ck::index_t s1 = 12;
constexpr ck::index_t s0 = 48;
const auto desc = ck::make_naive_tensor_descriptor(
ck::make_tuple(ck::Number<d3>{}, ck::Number<d2>{}, ck::Number<d1>{}, ck::Number<d0>{}),
ck::make_tuple(ck::Number<s3>{}, ck::Number<s2>{}, ck::Number<s1>{}, ck::Number<s0>{}));
// Reverse due to column major
const auto desc_1d = transform_tensor_descriptor(
desc,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3))),
ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}),
ck::make_tuple(ck::Sequence<0>{}));
const auto desc_3d = transform_tensor_descriptor(
desc,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3)),
ck::make_pass_through_transform(d1),
ck::make_pass_through_transform(d2)),
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
const auto layout_runtime =
ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(d3, d2), d1, d0),
ck::make_tuple(ck::make_tuple(s3, s2), s1, s0));
const auto layout_compiletime = ck::wrapper::make_layout(
ck::make_tuple(
ck::make_tuple(ck::Number<d3>{}, ck::Number<d2>{}), ck::Number<d1>{}, ck::Number<d0>{}),
ck::make_tuple(ck::make_tuple(ck::Number<s3>{}, ck::Number<s2>{}),
ck::Number<s1>{},
ck::Number<s0>{}));
std::vector<ck::Tuple<ck::index_t, ck::index_t, ck::index_t>> idxs_3d;
for(ck::index_t d = 0; d < d2 * d3; d++)
{
for(ck::index_t h = 0; h < d1; h++)
{
for(ck::index_t w = 0; w < d0; w++)
{
idxs_3d.emplace_back(d, h, w);
}
}
}
this->Run(desc_3d, desc_1d, layout_runtime, layout_compiletime, idxs_3d);
// Check also 4d iteration
std::vector<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::index_t, ck::index_t>> idxs_4d;
for(ck::index_t e = 0; e < d3; e++)
{
for(ck::index_t d = 0; d < d2; d++)
{
for(ck::index_t h = 0; h < d1; h++)
{
for(ck::index_t w = 0; w < d0; w++)
{
idxs_4d.emplace_back(ck::make_tuple(e, d), h, w);
}
}
}
}
this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_4d);
}
TEST_F(TestWrapperLayout, 2d_nested)
{
// dims:((2, 3), (4, 3)) strides:((2, 4), (48, 12))
constexpr ck::index_t d3 = 2;
constexpr ck::index_t d2 = 3;
constexpr ck::index_t d1 = 4;
constexpr ck::index_t d0 = 3;
constexpr ck::index_t s3 = 2;
constexpr ck::index_t s2 = 4;
constexpr ck::index_t s1 = 48;
constexpr ck::index_t s0 = 12;
const auto desc = ck::make_naive_tensor_descriptor(
ck::make_tuple(ck::Number<d3>{}, ck::Number<d2>{}, ck::Number<d1>{}, ck::Number<d0>{}),
ck::make_tuple(ck::Number<s3>{}, ck::Number<s2>{}, ck::Number<s1>{}, ck::Number<s0>{}));
// Reverse due to column major
const auto desc_1d = transform_tensor_descriptor(
desc,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3))),
ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}),
ck::make_tuple(ck::Sequence<0>{}));
const auto desc_2d = transform_tensor_descriptor(
desc,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3)),
ck::make_merge_transform(ck::make_tuple(d0, d1))),
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto layout_runtime =
ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(d3, d2), ck::make_tuple(d1, d0)),
ck::make_tuple(ck::make_tuple(s3, s2), ck::make_tuple(s1, s0)));
const auto layout_compiletime = ck::wrapper::make_layout(
ck::make_tuple(ck::make_tuple(ck::Number<d3>{}, ck::Number<d2>{}),
ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})),
ck::make_tuple(ck::make_tuple(ck::Number<s3>{}, ck::Number<s2>{}),
ck::make_tuple(ck::Number<s1>{}, ck::Number<s0>{})));
std::vector<ck::Tuple<ck::index_t, ck::index_t>> idxs_2d;
for(ck::index_t h = 0; h < d2 * d3; h++)
{
for(ck::index_t w = 0; w < d0 * d1; w++)
{
idxs_2d.emplace_back(h, w);
}
}
this->Run(desc_2d, desc_1d, layout_runtime, layout_compiletime, idxs_2d);
// Check also 4d iteration
std::vector<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::Tuple<ck::index_t, ck::index_t>>>
idxs_4d;
for(ck::index_t e = 0; e < d3; e++)
{
for(ck::index_t d = 0; d < d2; d++)
{
for(ck::index_t h = 0; h < d1; h++)
{
for(ck::index_t w = 0; w < d0; w++)
{
idxs_4d.emplace_back(ck::make_tuple(e, d), ck::make_tuple(h, w));
}
}
}
}
this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_4d);
}
TEST_F(TestWrapperLayout, 3d_double_nested)
{
// dims:(((2, 2), 3), (4, 3)) strides:(((2, 4), 8), (96, 24))
constexpr ck::index_t d4 = 2;
constexpr ck::index_t d3 = 2;
constexpr ck::index_t d2 = 3;
constexpr ck::index_t d1 = 4;
constexpr ck::index_t d0 = 3;
constexpr ck::index_t s4 = 2;
constexpr ck::index_t s3 = 4;
constexpr ck::index_t s2 = 8;
constexpr ck::index_t s1 = 96;
constexpr ck::index_t s0 = 24;
const auto desc = ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number<d4>{},
ck::Number<d3>{},
ck::Number<d2>{},
ck::Number<d1>{},
ck::Number<d0>{}),
ck::make_tuple(ck::Number<s4>{},
ck::Number<s3>{},
ck::Number<s2>{},
ck::Number<s1>{},
ck::Number<s0>{}));
// Reverse due to column major
const auto desc_1d = transform_tensor_descriptor(
desc,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3, d4))),
ck::make_tuple(ck::Sequence<4, 3, 2, 1, 0>{}),
ck::make_tuple(ck::Sequence<0>{}));
const auto desc_3d = transform_tensor_descriptor(
desc,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d3, d4)),
ck::make_pass_through_transform(d2),
ck::make_merge_transform(ck::make_tuple(d0, d1))),
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<4, 3>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
const auto desc_2d = transform_tensor_descriptor(
desc_3d,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3 * d4)),
ck::make_pass_through_transform(d1 * d0)),
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto layout_runtime = ck::wrapper::make_layout(
ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)),
ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, s3), s2), ck::make_tuple(s1, s0)));
const auto layout_compiletime = ck::wrapper::make_layout(
ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})),
ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<s3>{}), ck::Number<s2>{}),
ck::make_tuple(ck::Number<s1>{}, ck::Number<s0>{})));
std::vector<ck::Tuple<ck::index_t, ck::index_t>> idxs_2d;
for(ck::index_t h = 0; h < d2 * d3 * d4; h++)
{
for(ck::index_t w = 0; w < d0 * d1; w++)
{
idxs_2d.emplace_back(h, w);
}
}
this->Run(desc_2d, desc_1d, layout_runtime, layout_compiletime, idxs_2d);
// Check also 3d iteration
std::vector<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::index_t>> idxs_3d;
for(ck::index_t d = 0; d < d3 * d4; d++)
{
for(ck::index_t h = 0; h < d2; h++)
{
for(ck::index_t w = 0; w < d1 * d0; w++)
{
idxs_3d.emplace_back(ck::make_tuple(d, h), w);
}
}
}
this->Run(desc_3d, desc_1d, layout_runtime, layout_compiletime, idxs_3d);
// Check also 5d iteration
std::vector<ck::Tuple<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::index_t>,
ck::Tuple<ck::index_t, ck::index_t>>>
idxs_5d;
for(ck::index_t f = 0; f < d4; f++)
{
for(ck::index_t e = 0; e < d3; e++)
{
for(ck::index_t d = 0; d < d2; d++)
{
for(ck::index_t h = 0; h < d1; h++)
{
for(ck::index_t w = 0; w < d0; w++)
{
idxs_5d.emplace_back(ck::make_tuple(ck::make_tuple(f, e), d),
ck::make_tuple(h, w));
}
}
}
}
}
this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_5d);
}
TEST(TestLayoutHelpers, SizeAndGet)
{
// dims:(((2, 2), 3), (4, 3))
constexpr ck::index_t d4 = 2;
constexpr ck::index_t d3 = 2;
constexpr ck::index_t d2 = 3;
constexpr ck::index_t d1 = 4;
constexpr ck::index_t d0 = 3;
const auto layout_runtime = ck::wrapper::make_layout(
ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)));
const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})));
// Size of layout
EXPECT_EQ(ck::wrapper::size(layout_runtime), d4 * d3 * d2 * d1 * d0);
EXPECT_EQ(ck::wrapper::size(layout_compiletime), d4 * d3 * d2 * d1 * d0);
// Size of dims
EXPECT_EQ(ck::wrapper::size<0>(layout_runtime), d4 * d3 * d2);
EXPECT_EQ(ck::wrapper::size<0>(layout_compiletime), d4 * d3 * d2);
EXPECT_EQ(ck::wrapper::size<1>(layout_runtime), d1 * d0);
EXPECT_EQ(ck::wrapper::size<1>(layout_compiletime), d1 * d0);
// Access through new layout (using get with layout object)
EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_runtime)), d4 * d3);
EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_compiletime)), d4 * d3);
EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_runtime)), d2);
EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_compiletime)), d2);
EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_runtime))), d4);
EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_compiletime))),
d4);
EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_runtime))), d3);
EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_compiletime))),
d3);
EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_runtime)), d2);
EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_compiletime)), d2);
EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<1>(layout_runtime)), d1);
EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<1>(layout_compiletime)), d1);
EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<1>(layout_runtime)), d0);
EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<1>(layout_compiletime)), d0);
}
TEST(TestLayoutHelpers, DepthAndRank)
{
// dims:(((2, 2), 3), (4, 3))
constexpr ck::index_t d4 = 2;
constexpr ck::index_t d3 = 2;
constexpr ck::index_t d2 = 3;
constexpr ck::index_t d1 = 4;
constexpr ck::index_t d0 = 3;
const auto layout_runtime = ck::wrapper::make_layout(
ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)));
const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})));
EXPECT_EQ(ck::wrapper::depth(layout_runtime), 3);
EXPECT_EQ(ck::wrapper::depth(layout_compiletime), 3);
EXPECT_EQ(ck::wrapper::depth(ck::make_tuple(ck::make_tuple(d4, d3), d2)), 2);
// Check for integer
EXPECT_EQ(ck::wrapper::depth(d0), 0);
EXPECT_EQ(ck::wrapper::rank(layout_runtime), 2);
EXPECT_EQ(ck::wrapper::rank(layout_compiletime), 2);
EXPECT_EQ(ck::wrapper::rank(ck::make_tuple(ck::make_tuple(d4, d3), d2)), 2);
// Check for integer
EXPECT_EQ(ck::wrapper::rank(d0), 1);
}
TEST(TestLayoutHelpers, ShapeAndStrides)
{
// dims:(((2, 2), 3), (4, 3))
constexpr ck::index_t d4 = 2;
constexpr ck::index_t d3 = 2;
constexpr ck::index_t d2 = 3;
constexpr ck::index_t d1 = 4;
constexpr ck::index_t d0 = 3;
constexpr ck::index_t s4 = 2;
constexpr ck::index_t s3 = 4;
constexpr ck::index_t s2 = 8;
constexpr ck::index_t s1 = 96;
constexpr ck::index_t s0 = 24;
const auto shape_compiletime = ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{}));
const auto strides_compiletime = ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<s4>{}, ck::Number<s3>{}), ck::Number<s2>{}),
ck::make_tuple(ck::Number<s1>{}, ck::Number<s0>{}));
const auto shape_runtime =
ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0));
const auto strides_runtime =
ck::make_tuple(ck::make_tuple(ck::make_tuple(s4, s3), s2), ck::make_tuple(s1, s0));
const auto layout_runtime = ck::wrapper::make_layout(shape_runtime, strides_runtime);
const auto layout_compiletime =
ck::wrapper::make_layout(shape_compiletime, strides_compiletime);
constexpr bool check_compiletime_shape =
std::is_same_v<decltype(shape_compiletime),
std::remove_reference_t<decltype(shape(layout_compiletime))>>;
constexpr bool check_compiletime_strides =
std::is_same_v<decltype(strides_compiletime),
std::remove_reference_t<decltype(stride(layout_compiletime))>>;
constexpr bool check_runtime_shape =
std::is_same_v<decltype(shape_runtime),
std::remove_reference_t<decltype(shape(layout_runtime))>>;
constexpr bool check_runtime_strides =
std::is_same_v<decltype(strides_runtime),
std::remove_reference_t<decltype(stride(layout_runtime))>>;
EXPECT_TRUE(check_compiletime_shape);
EXPECT_TRUE(check_compiletime_strides);
EXPECT_TRUE(check_runtime_shape);
EXPECT_TRUE(check_runtime_strides);
}
TEST(TestLayoutHelpers, Hierarchical)
{
// dims:(((2, 2), 3), (4, 3))
constexpr ck::index_t d4 = 2;
constexpr ck::index_t d3 = 2;
constexpr ck::index_t d2 = 3;
constexpr ck::index_t d1 = 4;
constexpr ck::index_t d0 = 3;
const auto runtime_shape =
ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0));
const auto layout_runtime = ck::wrapper::make_layout(runtime_shape);
const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})));
EXPECT_EQ((ck::wrapper::rank<0, 0>(runtime_shape)), 2);
EXPECT_EQ((ck::wrapper::rank<0, 0>(layout_runtime)), 2);
EXPECT_EQ((ck::wrapper::rank<0, 0>(layout_compiletime)), 2);
EXPECT_EQ((ck::wrapper::depth<0, 0>(runtime_shape)), 1);
EXPECT_EQ((ck::wrapper::depth<0, 0>(layout_runtime)), 1);
EXPECT_EQ((ck::wrapper::depth<0, 0>(layout_compiletime)), 1);
EXPECT_EQ((ck::wrapper::size<0, 0>(runtime_shape)), d4 * d3);
EXPECT_EQ((ck::wrapper::size<0, 0>(layout_runtime)), d4 * d3);
EXPECT_EQ((ck::wrapper::size<0, 0>(layout_compiletime)), d4 * d3);
EXPECT_EQ((ck::wrapper::get<0, 0, 0>(runtime_shape)), d4);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment