Unverified commit fcbb9788, authored by Chao Liu, committed by GitHub

Dynamic tensor descriptor (#24)



* support dynamic tensor descriptor

* use buffer load OOB feature for padding case

* add navi support

* add int8x4 inference kernel
Co-authored-by: Chao Liu <chao@ixt-rack-81.local.lan>
Co-authored-by: Jing Zhang <jizhan@amd.com>
parent bbcb67d0
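
The commit's headline feature is a tensor descriptor whose shape can be carried as a run-time value instead of a template parameter. A minimal sketch of how such a descriptor is built, using the helpers that appear in the device code further down (make_dynamic_naive_tensor_descriptor_packed_v2, to_multi_index); the NCHW lengths here are hypothetical placeholders:

    using namespace ck;

    // packed NCHW input descriptor; the Sequence lengths are placeholder values
    // and could equally come from run-time integers folded into a multi-index
    const auto in_n_c_hi_wi_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
        to_multi_index(Sequence<128, 192, 71, 71>{}));
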
#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP
#define CK_STATICALLY_INDEXED_ARRAY_HPP
#include "functional2.hpp"
#include "sequence.hpp"
#include "tuple.hpp"
namespace ck {
namespace detail {
template <typename T, index_t NSize>
__host__ __device__ constexpr auto generate_same_type_tuple()
{
return generate_tuple([](auto) -> T { return T{}; }, Number<NSize>{});
}
template <typename T, index_t NSize>
using same_type_tuple = decltype(generate_same_type_tuple<T, NSize>());
} // namespace detail
template <typename T, index_t NSize>
using StaticallyIndexedArray = detail::same_type_tuple<T, NSize>;
template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
return StaticallyIndexedArray<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
}
// make empty StaticallyIndexedArray
template <typename X>
__host__ __device__ constexpr auto make_statically_indexed_array()
{
return StaticallyIndexedArray<X, 0>();
}
} // namespace ck
#endif
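
StaticallyIndexedArray is simply a Tuple of NSize copies of T, so element access happens by compile-time index. A minimal usage sketch, assuming the operator[]/operator() members added to Tuple in tuple.hpp below:

    using namespace ck;

    auto a = make_statically_indexed_array(1, 2, 3); // StaticallyIndexedArray<int, 3>

    static_assert(decltype(a)::Size() == 3, "three elements");

    index_t x = a[Number<0>{}]; // compile-time indexed read
    a(Number<1>{}) = 4;         // compile-time indexed write
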
@@ -2,8 +2,8 @@
 #define CK_TUPLE_HPP

 #include "integral_constant.hpp"
-#include "type.hpp"
 #include "sequence.hpp"
+#include "type.hpp"

 namespace ck {
@@ -12,15 +12,19 @@ namespace detail {
 template <index_t>
 struct TupleElementKey
 {
+    __host__ __device__ constexpr TupleElementKey() = default;
 };

 template <typename Key, typename Data>
 struct TupleElement
 {
-    __host__ __device__ explicit constexpr TupleElement() : mData() {}
+    __host__ __device__ constexpr TupleElement() = default;

-    template <typename T>
-    __host__ __device__ explicit constexpr TupleElement(T&& v) : mData(static_cast<T&&>(v))
+    template <
+        typename T,
+        typename std::enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
+                                bool>::type = false>
+    __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
     {
     }
@@ -30,7 +34,7 @@ struct TupleElement
 template <typename Key, typename Data>
 __host__ __device__ constexpr const Data& get_tuple_element(const TupleElement<Key, Data>& x)
 {
-    return x.mData;
+    return static_cast<const Data&>(x.mData);
 }

 template <typename Key, typename Data>
@@ -39,6 +43,7 @@ __host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x
     return x.mData;
 }

+// TODO: not sure the use of reference is correct
 template <typename Key, typename Data>
 __host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
 {
@@ -51,14 +56,24 @@ struct TupleImpl;
 template <index_t... Is, typename... Xs>
 struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
 {
-    __host__ __device__ explicit constexpr TupleImpl() : TupleElement<TupleElementKey<Is>, Xs>()...
+    __host__ __device__ constexpr TupleImpl() = default;
+
+    template <
+        typename Y,
+        typename std::enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
+                                    !is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
+                                bool>::type = false>
+    __host__ __device__ constexpr TupleImpl(Y&& y)
+        : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
     {
     }

-    template <typename... Ys>
-    __host__ __device__ explicit constexpr TupleImpl(Ys&&... ys)
-        : TupleElement<TupleElementKey<Is>, Xs>(static_cast<Ys&&>(ys))...
+    template <typename... Ys, typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
+    __host__ __device__ constexpr TupleImpl(Ys&&... ys)
+        : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
     {
+        static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
+                      "wrong! inconsistent size");
     }

     __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
@@ -84,11 +99,25 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
     using base =
         detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>;

-    template <typename... Ys>
-    __host__ __device__ explicit constexpr Tuple(Ys&&... ys) : base(static_cast<Ys&&>(ys)...)
+    __host__ __device__ constexpr Tuple() = default;
+
+    template <typename Y,
+              typename std::enable_if<
+                  sizeof...(Xs) == 1 && !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
+                  bool>::type = false>
+    __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
     {
     }

+    template <typename... Ys,
+              typename std::enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2,
+                                      bool>::type = false>
+    __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
+    {
+    }
+
+    __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
+
     template <index_t I>
     __host__ __device__ constexpr const auto& At(Number<I>) const
     {
@@ -102,57 +131,34 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
         static_assert(I < base::Size(), "wrong! out of range");
         return base::GetElementByKey(detail::TupleElementKey<I>{});
     }
+
+    template <index_t I>
+    __host__ __device__ constexpr const auto& operator[](Number<I> i) const
+    {
+        return At(i);
+    }
+
+    template <index_t I>
+    __host__ __device__ constexpr auto& operator()(Number<I> i)
+    {
+        return At(i);
+    }
+
+    template <typename T>
+    __host__ __device__ constexpr auto operator=(const T& a)
+    {
+        static_assert(T::Size() == Size(), "wrong! size not the same");
+
+        static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
+
+        return *this;
+    }
 };

 template <typename... Xs>
 __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
 {
     return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...);
 }

-namespace detail {
-
-template <typename F, typename X, index_t... Is>
-__host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence<Is...>)
-{
-    return make_tuple(f(x.At(Number<Is>{}))...);
-}
-
-template <typename F, typename X, typename Y, index_t... Is>
-__host__ __device__ constexpr auto
-transform_tuples_impl(F f, const X& x, const Y& y, Sequence<Is...>)
-{
-    return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}))...);
-}
-
-template <typename F, typename X, typename Y, typename Z, index_t... Is>
-__host__ __device__ constexpr auto
-transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence<Is...>)
-{
-    return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
-}
-
-} // namespace detail
-
-template <typename F, typename X>
-__host__ __device__ constexpr auto transform_tuples(F f, const X& x)
-{
-    return detail::transform_tuples_impl(
-        f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
-}
-
-template <typename F, typename X, typename Y>
-__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y)
-{
-    return detail::transform_tuples_impl(
-        f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
-}
-
-template <typename F, typename X, typename Y, typename Z>
-__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
-{
-    return detail::transform_tuples_impl(
-        f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
-}

 } // namespace ck
...
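
The operators added to Tuple above give it array-like ergonomics. A small usage sketch (the example values are illustrative, not from the commit):

    using namespace ck;

    auto t = make_tuple(0, 0, 0);    // Tuple<int, int, int>
    auto s = make_tuple(4L, 5L, 6L); // Tuple<long, long, long>

    t = s;                       // elementwise copy via the new operator= (static_for)
    int first = t[Number<0>{}];  // const access, forwards to At()
    t(Number<2>{}) = 7;          // mutable access
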
#ifndef CK_TUPLE_HELPER_HPP
#define CK_TUPLE_HELPER_HPP
#include "functional4.hpp"
#include "tuple.hpp"
namespace ck {
template <typename... Ts>
struct is_known_at_compile_time<Tuple<Ts...>>
{
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return container_reduce(
Tuple<Ts...>{},
[](auto x, bool r) {
return is_known_at_compile_time<
remove_cv_t<remove_reference_t<decltype(x)>>>::value &
r;
},
true);
}
static constexpr bool value = IsKnownAtCompileTime();
};
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
{
return unpack([&f](auto&&... xs) { return make_tuple(f(xs)...); },
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
namespace detail {
template <typename F, typename X, index_t... Is>
__host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence<Is...>)
{
return make_tuple(f(x.At(Number<Is>{}))...);
}
template <typename F, typename X, typename Y, index_t... Is>
__host__ __device__ constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, Sequence<Is...>)
{
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}))...);
}
template <typename F, typename X, typename Y, typename Z, index_t... Is>
__host__ __device__ constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence<Is...>)
{
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
}
} // namespace detail
template <typename F, typename X>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x)
{
return detail::transform_tuples_impl(
f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
return detail::transform_tuples_impl(
f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
template <typename F, typename X, typename Y, typename Z>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
return detail::transform_tuples_impl(
f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
} // namespace ck
#endif
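
The transform helpers moved here from tuple.hpp unchanged; together with generate_tuple they cover the common tuple-construction patterns. A quick sketch:

    using namespace ck;

    // Tuple<Number<0>, Number<1>, Number<2>> built from the index sequence 0..2
    auto ids = generate_tuple([](auto i) { return i; }, Number<3>{});

    // elementwise combination of two tuples: yields make_tuple(11, 22, 33)
    auto sums = transform_tuples([](auto x, auto y) { return x + y; },
                                 make_tuple(1, 2, 3),
                                 make_tuple(10, 20, 30));
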
@@ -5,9 +5,6 @@
 namespace ck {

-template <index_t... Is>
-struct Sequence;
-
 template <typename X, typename Y>
 struct is_same : public integral_constant<bool, false>
 {
@@ -18,26 +15,32 @@ struct is_same<X, X> : public integral_constant<bool, true>
 {
 };

-template <typename>
-struct is_static : integral_constant<bool, false>
-{
-};
+template <typename T>
+using remove_reference_t = typename std::remove_reference<T>::type;

-template <typename T, T X>
-struct is_static<integral_constant<T, X>> : integral_constant<bool, true>
-{
-};
+template <typename T>
+using remove_cv_t = typename std::remove_cv<T>::type;

-template <index_t... Is>
-struct is_static<Sequence<Is...>> : integral_constant<bool, true>
-{
-};
+template <typename T>
+constexpr std::remove_reference_t<T>&& move(T&& t) noexcept
+{
+    return static_cast<typename std::remove_reference<T>::type&&>(t);
+}

-template <typename T>
-using remove_reference_t = typename std::remove_reference<T>::type;
+template <typename T>
+struct is_known_at_compile_time;

-template <typename T>
-using remove_cv_t = typename std::remove_cv<T>::type;
+template <>
+struct is_known_at_compile_time<index_t>
+{
+    static constexpr bool value = false;
+};
+
+template <typename T, T X>
+struct is_known_at_compile_time<integral_constant<T, X>>
+{
+    static constexpr bool value = true;
+};

 } // namespace ck
 #endif
@@ -9,6 +9,6 @@ __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

 __device__ index_t get_block_1d_id() { return blockIdx.x; }

-} // namspace ck
+} // namespace ck

 #endif
@@ -51,26 +51,24 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(
 }

 template <class InDesc, class WeiDesc, class OutDesc>
-constexpr std::size_t calculate_convolution_flops(InDesc, WeiDesc, OutDesc)
+constexpr std::size_t
+calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, const OutDesc& out_desc)
 {
     using namespace ck;

-    constexpr auto wei_desc = WeiDesc{};
-    constexpr auto out_desc = OutDesc{};
-
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
     constexpr auto I2 = Number<2>{};
     constexpr auto I3 = Number<3>{};

-    constexpr index_t N  = out_desc.GetLength(I0);
-    constexpr index_t K  = out_desc.GetLength(I1);
-    constexpr index_t Ho = out_desc.GetLength(I2);
-    constexpr index_t Wo = out_desc.GetLength(I3);
+    const index_t N  = out_desc.GetLength(I0);
+    const index_t K  = out_desc.GetLength(I1);
+    const index_t Ho = out_desc.GetLength(I2);
+    const index_t Wo = out_desc.GetLength(I3);

-    constexpr index_t C = wei_desc.GetLength(I1);
-    constexpr index_t Y = wei_desc.GetLength(I2);
-    constexpr index_t X = wei_desc.GetLength(I3);
+    const index_t C = wei_desc.GetLength(I1);
+    const index_t Y = wei_desc.GetLength(I2);
+    const index_t X = wei_desc.GetLength(I3);

     return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
 }
...
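
The returned count is two operations (one multiply, one add) per filter tap per output element. A worked example with hypothetical layer sizes:

    // flops = 2 * N * K * Ho * Wo * C * Y * X
    //       = 2 * 64 * 256 * 14 * 14 * 128 * 3 * 3
    //       = 7,398,752,256 ops, i.e. roughly 7.4 GFLOP per forward pass
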
@@ -183,7 +183,7 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
         GemmBBlockCopyDstDataPerWrite_GemmN,
         GemmCThreadCopyDstDataPerWrite_GemmN1>;

-    for(index_t i = 0; i < 5; ++i)
+    for(index_t i = 0; i < 1; ++i)
     {
         std::cout << "Start running " << nrepeat << " times..." << std::endl;
...
@@ -57,10 +57,41 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
     wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
     out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

-#if 0
+#if 1
     // cdata = 64, BlockSize = 256, 128x128x8
     constexpr index_t BlockSize = 256;

+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 8;
+
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
+
+    constexpr index_t GemmMLevel0Cluster = 2;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 8;
+    constexpr index_t GemmNLevel1Cluster = 8;
+
+    constexpr index_t GemmThreadGemmDataPerReadM = 4;
+    constexpr index_t GemmThreadGemmDataPerReadN = 4;
+
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<1, 4>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
+
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM  = 1;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
+
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<4, 1>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
+
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
+
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
+#elif 1
+    // cdata = 64, BlockSize = 256, 128x128x8
+    // GemmABlockCopySrcDataPerRead_GemmM = 4
+    constexpr index_t BlockSize = 256;

     constexpr index_t GemmMPerBlock = 128;
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 8;
@@ -74,11 +105,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmThreadGemmDataPerReadM = 4;
     constexpr index_t GemmThreadGemmDataPerReadN = 4;

-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<1, 4>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;

-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM  = 1;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM  = 4;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;

     using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<4, 1>;
     using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
@@ -104,11 +135,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmThreadGemmDataPerReadM = 4;
     constexpr index_t GemmThreadGemmDataPerReadN = 4;

-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<8, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<1, 8>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<16, 16>;

     constexpr index_t GemmABlockCopySrcDataPerRead_GemmM  = 1;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;

     using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<8, 1>;
     using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
...
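
The block-copy parameters above are tightly coupled: each thread's slice shape times the thread-cluster shape must cover the block tile, and the cluster must employ every thread. A consistency sketch for the 128x128x8 configuration; these checks merely restate the constants above and are not asserted anywhere in the source:

    // level-0/level-1 cluster product supplies the whole workgroup
    static_assert(2 * 2 * 8 * 8 == 256, "cluster product == BlockSize");

    // A copy: 1x4 slices on an 8x32 cluster tile (1*8) x (4*32) = 8 x 128,
    // exactly GemmKPerBlock x GemmMPerBlock
    static_assert(1 * 8 == 8 && 4 * 32 == 128, "A copy covers the A block tile");

    // B copy: 4x1 slices on a 2x128 cluster tile (4*2) x (1*128) = 8 x 128
    static_assert(4 * 2 == 8 && 1 * 128 == 128, "B copy covers the B block tile");
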
@@ -222,7 +222,7 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk(InDesc i
     static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) {
         constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);

-        constexpr index_t gemm_k2 = gemm_sizes.At(4);
+        constexpr index_t gemm_k2 = gemm_sizes[Number<4>{}];

         constexpr bool is_gemm_not_empty = gemm_k2 > 0;

         // only compile and run if GEMM is no empty
...
@@ -3,7 +3,7 @@
 #include "device.hpp"
 #include "host_tensor.hpp"
 #include "gridwise_operation_wrapper.hpp"
-#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp"
+#include "gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp"

 template <typename T,
           typename InDesc,
@@ -13,18 +13,20 @@ template <typename T,
           typename ConvDilations,
           typename LeftPads,
           typename RightPads>
-void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
+void device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
                                                           const Tensor<T>& in_nchw,
                                                           WeiDesc,
                                                           const Tensor<T>& wei_kcyx,
                                                           OutDesc,
                                                           Tensor<T>& out_nkhw,
                                                           ConvStrides,
                                                           ConvDilations,
                                                           LeftPads,
                                                           RightPads,
                                                           ck::index_t nrepeat)
 {
+    std::cout << "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw" << std::endl;
+
     using namespace ck;

     using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
@@ -133,7 +135,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
     constexpr index_t WeiBlockCopySrcDataPerRead_E  = 2;
     constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;

-#elif 0
+#elif 1
     // cdata = 64, BlockSize = 256, 128x128x8
     constexpr index_t BlockSize = 256;
@@ -770,45 +772,46 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

-    using gridwise_conv = GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer<
+    using gridwise_conv =
+        GridwiseConvolutionForwardImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer<
             GridSize,
             BlockSize,
             T,
             T,
             decltype(in_nchw_desc),
             decltype(wei_kcyx_desc),
             decltype(out_nkhw_desc),
             ConvStrides,
             ConvDilations,
             LeftPads,
             RightPads,
             BPerBlock,
             KPerBlock,
             EPerBlock,
             GemmNRepeat,
             GemmMPerThread,
             GemmNPerThread,
             GemmKPerThread,
             GemmMLevel0Cluster,
             GemmNLevel0Cluster,
             GemmMLevel1Cluster,
             GemmNLevel1Cluster,
             GemmDataPerReadA,
             GemmDataPerReadB,
             InBlockCopySubLengths_E_N1_B_N2,
             InBlockCopyClusterLengths_E_N1_B_N2,
             InBlockCopyThreadClusterArrangeOrder,
             InBlockCopySrcAccessOrder,
             InBlockCopyDstAccessOrder,
             InBlockCopySrcDataPerRead_B,
             InBlockCopyDstDataPerWrite_N2,
             WeiBlockCopySubLengths_E_K,
             WeiBlockCopyClusterLengths_E_K,
             WeiBlockCopyThreadClusterArrangeOrder,
             WeiBlockCopySrcAccessOrder,
             WeiBlockCopyDstAccessOrder,
             WeiBlockCopySrcDataPerRead_E,
             WeiBlockCopyDstDataPerWrite_K>;

     for(index_t i = 0; i < 5; ++i)
     {
...
@@ -2,7 +2,7 @@
 #include "device.hpp"
 #include "host_tensor.hpp"
 #include "gridwise_operation_wrapper.hpp"
-#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
+#include "gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"

 template <class T,
           class InDesc,
@@ -12,18 +12,20 @@ template <class T,
           class ConvDilations,
           class InLeftPads,
           class InRightPads>
-void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
+void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
                                                           const Tensor<T>& in_nchw,
                                                           WeiDesc,
                                                           const Tensor<T>& wei_kcyx,
                                                           OutDesc,
                                                           Tensor<T>& out_nkhw,
                                                           ConvStrides,
                                                           ConvDilations,
                                                           InLeftPads,
                                                           InRightPads,
                                                           ck::index_t nrepeat)
 {
+    std::cout << "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw" << std::endl;
+
     using namespace ck;
     using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
@@ -55,6 +57,109 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

 #if 0
// cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 2;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 2;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2;
#elif 0
// cdata = 16, BlockSize = 64, 16x64x4
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 2
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 2;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
#elif 0
// cdata = 32, BlockSize = 64, 16x128x4
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 0
 // cdata = 64, BlockSize = 256, 64x256x8
 constexpr index_t BlockSize = 256;
@@ -62,14 +167,14 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 constexpr index_t GemmNPerBlock = 256;
 constexpr index_t GemmKPerBlock = 8;

 constexpr index_t GemmMPerThread = 4;
 constexpr index_t GemmNPerThread = 4;
 constexpr index_t GemmKPerThread = 1;

 constexpr index_t GemmMLevel0Cluster = 2;
 constexpr index_t GemmNLevel0Cluster = 2;
 constexpr index_t GemmMLevel1Cluster = 4;
 constexpr index_t GemmNLevel1Cluster = 16;

 constexpr index_t ThreadGemmDataPerReadM = 4;
 constexpr index_t ThreadGemmDataPerReadN = 4;
@@ -86,6 +191,39 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
 constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

+constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 256, 128x128x2
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
 constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #elif 0
 // cdata = 64, BlockSize = 256, 128x128x4
@@ -99,10 +237,10 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 constexpr index_t GemmNPerThread = 4;
 constexpr index_t GemmKPerThread = 1;

-constexpr index_t GemmMLevel0Cluster = 4;
-constexpr index_t GemmNLevel0Cluster = 4;
-constexpr index_t GemmMLevel1Cluster = 4;
-constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmMLevel0Cluster = 2;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 8;
+constexpr index_t GemmNLevel1Cluster = 8;

 constexpr index_t ThreadGemmDataPerReadM = 4;
 constexpr index_t ThreadGemmDataPerReadN = 4;
@@ -122,6 +260,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;

 #elif 1
 // cdata = 64, BlockSize = 256, 128x128x8
+// b threadwise copy 4x1
 constexpr index_t BlockSize = 256;

 constexpr index_t GemmMPerBlock = 128;
@@ -152,6 +291,40 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
 constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 256, 128x128x8
// b threadwise copy 2x2
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
 constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #elif 0
 // cdata = 64, BlockSize = 256, 128x128x8
@@ -255,7 +428,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
 constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;

-#elif 0
+#elif 1
 // cdata = 64, BlockSize = 256, 128x128x16
 // GemmBBlockCopySrcDataPerRead_GemmN = 4
 // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
@@ -289,6 +462,41 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 4;
 constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x16
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 8>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<16, 16>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
 constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #elif 0
 // cdata = 64, BlockSize = 128, 128x64x4
@@ -826,7 +1034,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
 constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;

-#elif 1
+#elif 0
 // cdata = 64, BlockSize = 64, 64x64x3
 constexpr index_t BlockSize = 64;
@@ -968,7 +1176,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

-using gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw<
+using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw<
     GridSize,
     BlockSize,
     TDevice,
...
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations,
class InLeftPads,
class InRightPads>
void device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ck::index_t nrepeat)
{
std::cout << "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk" << std::endl;
using namespace ck;
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto N = OutDesc::GetLengths()[I0];
constexpr auto K = OutDesc::GetLengths()[I1];
constexpr auto C = WeiDesc::GetLengths()[I1];
constexpr auto Hi = InDesc::GetLengths()[I2];
constexpr auto Wi = InDesc::GetLengths()[I3];
constexpr auto Ho = OutDesc::GetLengths()[I2];
constexpr auto Wo = OutDesc::GetLengths()[I3];
constexpr auto Y = WeiDesc::GetLengths()[I2];
constexpr auto X = WeiDesc::GetLengths()[I3];
// compile-time variables
constexpr auto in_n_hi_wi_c_desc =
make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{});
constexpr auto wei_k_y_x_c_desc = make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{});
constexpr auto out_n_ho_wo_k_desc =
make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{});
Tensor<float> in_nhwc(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{})));
Tensor<float> wei_kyxc(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{})));
Tensor<float> out_nhwk(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{})));
auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
in_nhwc(n, hi, wi, c) = in_nchw(n, c, hi, wi);
};
auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
wei_kyxc(k, y, x, c) = wei_kcyx(k, c, y, x);
};
auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo);
};
make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency());
std::size_t data_sz = sizeof(T);
DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace());
DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace());
DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace());
in_nhwc_device_buf.ToDevice(in_nhwc.mData.data());
wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data());
out_nhwk_device_buf.ToDevice(out_nhwk.mData.data());
#if 1
// cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 2;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 2;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmM1 = 2;
#endif
constexpr index_t GemmM = K;
constexpr index_t GemmN = N * Ho * Wo;
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
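// example with hypothetical sizes (not from this file): for the 16x64x4
// configuration above with K = 16, N = 128, Ho = Wo = 28, this gives GemmM = 16
// and GemmN = 128 * 28 * 28 = 100352, so
// GridSize = ceil(16 / 16) * ceil(100352 / 64) = 1 * 1568 = 1568 workgroups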
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk<
GridSize,
BlockSize,
TDevice,
TDevice,
decltype(in_n_hi_wi_c_desc),
decltype(wei_k_y_x_c_desc),
decltype(out_n_ho_wo_k_desc),
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopySrcDataPerRead_GemmK,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmK,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmM1>;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
launch_kernel(run_gridwise_operation<gridwise_conv,
const TDevice* const __restrict__,
const TDevice* const __restrict__,
TDevice* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<TDevice*>(in_nhwc_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(wei_kyxc_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(out_nhwk_device_buf.GetDeviceBuffer()));
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
out_nhwk_device_buf.FromDevice(out_nhwk.mData.data());
auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
out_nkhw(n, k, ho, wo) = out_nhwk(n, ho, wo, k);
};
make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)(std::thread::hardware_concurrency());
}
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
template <class TInWei,
ck::index_t InWeiVectorSize,
class TAcc,
class TOut,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations,
class InLeftPads,
class InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
InDesc,
const Tensor<TInWei>& in_n_c_hi_wi,
WeiDesc,
const Tensor<TInWei>& wei_k_c_y_x,
OutDesc,
Tensor<TOut>& out_n_k_ho_wo,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw"
<< std::endl;
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
#if 0
// run-time variables
const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
const auto conv_strides = to_multi_index(ConvStrides{});
const auto conv_dilations = to_multi_index(ConvDilations{});
const auto in_left_pads = to_multi_index(InLeftPads{});
const auto in_right_pads = to_multi_index(InRightPads{});
#else
// compile-time variables
const auto in_n_c_hi_wi_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(InDesc::GetLengths()));
const auto wei_k_c_y_x_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(WeiDesc::GetLengths()));
const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(OutDesc::GetLengths()));
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif
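// Both branches above feed the same dynamic-descriptor API: the disabled #if 0
// path builds the descriptors from plain run-time integers via to_multi_index,
// while the enabled #else path wraps the compile-time Sequence lengths into
// Number<> tuples via sequence_to_tuple_of_number, so the lengths remain
// compile-time constants even though the descriptor type itself is "dynamic".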
#if 0
// cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 2;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
#elif 0
// cdata = 32, BlockSize 64, 16x128x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 0
// cdata = 64, BlockSize 64, 16x256x2
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 1
// cdata = 64, BlockSize 64, 16x256x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 0
// cdata = 16, BlockSize = 64, 16x64x4
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 2
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 2;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
#elif 0
// cdata = 32, BlockSize = 64, 16x128x4
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 0
// cdata = 64, BlockSize = 128, 32x256x8
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 256, 128x128x2
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 256, 128x128x4
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x8
// b thread copy 4x1
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x8
// b thread copy 2x2
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x16
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#endif
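// Sanity-check sketch (editorial addition, assuming the usual v4r4 blockwise-GEMM
// decomposition; these asserts are illustrative and not part of the original driver).
// The four cluster sizes must tile the thread block exactly, and "cdata" in the
// comments above is the per-thread output count, (GemmMPerBlock * GemmNPerBlock) / BlockSize.
static_assert(BlockSize == GemmMLevel0Cluster * GemmNLevel0Cluster * GemmMLevel1Cluster *
                               GemmNLevel1Cluster,
              "thread clusters must tile the thread block exactly");
static_assert(GemmMPerBlock % (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
                  GemmNPerBlock % (GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
              "per-thread tiles (times repeats) must cover the block tile");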
constexpr auto conv_driver =
#if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
#elif 0
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
#elif 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
#endif
<BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc,
TOut,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
GemmCThreadTransferDstScalarPerVector_GemmN1>{};
conv_driver.Run(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
template <class TInWei,
ck::index_t InWeiVectorSize,
class TAcc,
class TOut,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations,
class InLeftPads,
class InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
InDesc,
const Tensor<TInWei>& in_n_c_hi_wi,
WeiDesc,
const Tensor<TInWei>& wei_k_c_y_x,
OutDesc,
Tensor<TOut>& out_n_k_ho_wo,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk"
<< std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto N = OutDesc::GetLengths()[I0];
constexpr auto K = OutDesc::GetLengths()[I1];
constexpr auto C = WeiDesc::GetLengths()[I1];
constexpr auto Hi = InDesc::GetLengths()[I2];
constexpr auto Wi = InDesc::GetLengths()[I3];
constexpr auto Ho = OutDesc::GetLengths()[I2];
constexpr auto Wo = OutDesc::GetLengths()[I3];
constexpr auto Y = WeiDesc::GetLengths()[I2];
constexpr auto X = WeiDesc::GetLengths()[I3];
constexpr auto C0 = C / Number<InWeiVectorSize>{};
constexpr auto C1 = Number<InWeiVectorSize>{};
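// Editorial note: this split assumes C is divisible by InWeiVectorSize; e.g. the
// int8x4 path with C = 16 and InWeiVectorSize = 4 yields C0 = 4 channel groups of
// C1 = 4 packed values each.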
#if 0
// run-time variables
const auto in_n_hi_wi_c0_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
const auto wei_k_y_x_c0_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, Y, X, C0));
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Ho, Wo, K));
const auto conv_strides = to_multi_index(ConvStrides{});
const auto conv_dilations = to_multi_index(ConvDilations{});
const auto in_left_pads = to_multi_index(InLeftPads{});
const auto in_right_pads = to_multi_index(InRightPads{});
#else
// compile-time variables
constexpr auto in_n_hi_wi_c0_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Hi, Wi, C0));
constexpr auto wei_k_y_x_c0_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y, X, C0));
constexpr auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Ho, Wo, K));
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif
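// Editorial note on the two branches above: the #if 0 branch builds descriptors
// from run-time multi-indices (lengths known only at run time), while the #else
// branch builds them from tuples of compile-time Number<>s so the strides fold
// into constants. Either way a packed N-Hi-Wi-C0 descriptor gets strides
// {Hi*Wi*C0, Wi*C0, C0, 1}; e.g. for Hi = 1080, Wi = 1920, C0 = 4 that is
// {8294400, 7680, 4, 1}.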
Tensor<TInWei> in_n_hi_wi_c(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{})));
Tensor<TInWei> wei_k_y_x_c(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{})));
Tensor<TOut> out_n_ho_wo_k(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{})));
auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
in_n_hi_wi_c(n, hi, wi, c) = in_n_c_hi_wi(n, c, hi, wi);
};
auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
wei_k_y_x_c(k, y, x, c) = wei_k_c_y_x(k, c, y, x);
};
auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
out_n_ho_wo_k(n, ho, wo, k) = out_n_k_ho_wo(n, k, ho, wo);
};
make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)();
make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)();
make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)();
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
#if 1
// cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 2;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2;
#elif 0
// cdata = 32, BlockSize = 64, 16x128x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2;
#elif 0
// cdata = 64, BlockSize = 64, 16x256x2
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 1;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
#elif 1
// cdata = 64, BlockSize = 64, 16x256x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
#elif 0
// cdata = 64, BlockSize = 128, 32x256x4
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
#elif 0
// cdata = 64, BlockSize = 128, 32x256x8
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
#elif 0
// cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
#endif
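// Editorial note: unlike the NCHW path, these configs vectorize the B-block source
// reads along GemmK (GemmBBlockTransferSrcScalarPerVector_GemmK) and the C-thread
// writes along GemmM1 (GemmCThreadTransferDstScalarPerVector_GemmM1), since in
// NHWC/KYXC the innermost channel dimension feeds GemmK and the output's innermost
// K dimension feeds GemmM. The same cluster invariant as the NCHW path applies:
static_assert(BlockSize == GemmMLevel0Cluster * GemmNLevel0Cluster * GemmMLevel1Cluster *
                               GemmNLevel1Cluster,
              "thread clusters must tile the thread block exactly");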
constexpr auto conv_driver =
#if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
#elif 0
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_no_pad
#elif 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
#endif
<BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc,
TOut,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
GemmBBlockTransferSrcScalarPerVector_GemmK,
GemmBBlockTransferDstScalarPerVector_GemmN,
GemmCThreadTransferDstScalarPerVector_GemmM1>{};
conv_driver.Run(wei_k_y_x_c0_desc,
in_n_hi_wi_c0_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()));
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
out_n_k_ho_wo(n, k, ho, wo) = out_n_ho_wo_k(n, ho, wo, k);
};
make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)();
}
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
template <class TInWei,
ck::index_t InWeiVectorSize,
class TAcc,
class TOut,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations,
class InLeftPads,
class InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
InDesc,
const Tensor<TInWei>& in_n_c_hi_wi,
WeiDesc,
const Tensor<TInWei>& wei_k_c_y_x,
OutDesc,
Tensor<TOut>& out_n_k_ho_wo,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw"
<< std::endl;
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto N = OutDesc::GetLengths()[I0];
constexpr auto K = OutDesc::GetLengths()[I1];
constexpr auto C = WeiDesc::GetLengths()[I1];
constexpr auto Hi = InDesc::GetLengths()[I2];
constexpr auto Wi = InDesc::GetLengths()[I3];
constexpr auto Ho = OutDesc::GetLengths()[I2];
constexpr auto Wo = OutDesc::GetLengths()[I3];
constexpr auto Y = WeiDesc::GetLengths()[I2];
constexpr auto X = WeiDesc::GetLengths()[I3];
constexpr auto C0 = C / Number<InWeiVectorSize>{};
constexpr auto C1 = Number<InWeiVectorSize>{};
#if 0
// run-time variables
const auto in_n_c0_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, C0, Hi, Wi));
const auto wei_k_c0_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, C0, Y, X));
const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, K, Ho, Wo));
const auto conv_strides = to_multi_index(ConvStrides{});
const auto conv_dilations = to_multi_index(ConvDilations{});
const auto in_left_pads = to_multi_index(InLeftPads{});
const auto in_right_pads = to_multi_index(InRightPads{});
#else
// compile-time variables
const auto in_n_c0_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi));
const auto wei_k_c0_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo));
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif
Tensor<TInWei> in_n_c0_hi_wi_c1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<N, C0, Hi, Wi, C1>{})));
Tensor<TInWei> wei_k_c0_y_x_c1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<K, C0, Y, X, C1>{})));
auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
in_n_c_hi_wi(n, c, hi, wi);
};
auto f_kcyx2kc0yxc1 = [&](auto k, auto y, auto x, auto c) {
wei_k_c0_y_x_c1(k, c / InWeiVectorSize, y, x, c % InWeiVectorSize) =
wei_k_c_y_x(k, c, y, x);
};
make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)();
make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)();
in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
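// Editorial note: the device buffers keep their NCHW/KCYX names but now hold the
// N-C0-H-W-C1 / K-C0-Y-X-C1 packed data; the element counts agree
// (N*C*Hi*Wi == N*C0*Hi*Wi*C1 because C == C0*C1), so the allocations above are
// still correctly sized.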
// cdata = 64, BlockSize = 64, 16x8x32x4
constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 32;
constexpr index_t EPerBlock = 4;
constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock;
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, 16>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
constexpr index_t CThreadTransferDstScalarPerVector_W = 1;
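// Layout sketch (editorial; assumes the v5r1 GEMM reduction dimension is
// E = C0 * Y * X): with the 3x3 configs in this driver one C0-group contributes
// Y * X = 9 filter taps, which is exactly what the 9x1 A-thread slice spans,
// and the EPerBlock x 16 thread cluster accounts for 4 * 16 = 64 = BlockSize threads.
static_assert(BlockSize == EPerBlock * 16,
              "A-block transfer thread cluster must equal BlockSize");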
constexpr auto conv_driver =
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc,
TOut,
KPerBlock,
HoPerBlock,
WoPerBlock,
EPerBlock,
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
ABlockTransferSrcScalarPerVector_E,
ABlockTransferDstScalarPerVector_K,
BThreadTransferSrcScalarPerVector_W,
CThreadTransferDstScalarPerVector_W>{};
conv_driver.Run(wei_k_c0_y_x_desc,
in_n_c0_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}
@@ -273,7 +273,7 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
                     std::size_t ho = HoPerTile * htile + j;
                     for(int i = 0; i < WoPerTile; ++i)
                     {
                         std::size_t wo = WoPerTile * wtile + i;
                         out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
                     }
                 }
...
@@ -158,7 +158,7 @@ struct ParallelTensorFunctor
         return indices;
     }

-    void operator()(std::size_t num_thread) const
+    void operator()(std::size_t num_thread = std::thread::hardware_concurrency()) const
     {
         std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread;
...
@@ -4,10 +4,7 @@
 #include <cstdlib>
 #include <stdlib.h>
 #include "config.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
-#include "print_array.hpp"
-#include "print_sequence.hpp"
+#include "print.hpp"
 #include "device.hpp"
 #include "host_tensor_generator.hpp"
 #include "device_tensor.hpp"
@@ -54,10 +51,10 @@ int main(int argc, char* argv[])
 #elif 0
     // 3x3, 28x28
     constexpr index_t N = 128;
-    constexpr index_t C = 256;
+    constexpr index_t C = 128;
     constexpr index_t HI = 28;
     constexpr index_t WI = 28;
-    constexpr index_t K = 1024;
+    constexpr index_t K = 128;
     constexpr index_t Y = 3;
    constexpr index_t X = 3;
@@ -156,13 +153,13 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<2, 2>;
     using RightPads = Sequence<2, 2>;

-#elif 0
+#elif 1
     // 1x7 filter, 0x3 pad, 17x17 input
     constexpr index_t N = 128;
-    constexpr index_t C = 256;
+    constexpr index_t C = 128;
     constexpr index_t HI = 17;
     constexpr index_t WI = 17;
-    constexpr index_t K = 1024;
+    constexpr index_t K = 128;
     constexpr index_t Y = 1;
     constexpr index_t X = 7;
@@ -197,7 +194,7 @@ int main(int argc, char* argv[])
     constexpr index_t X = 3;

     using ConvStrides = Sequence<2, 2>;
-    using ConvDilations = Sequence<1, 1>;
+    using ConvDilations = Sequence<2, 2>;

     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
@@ -211,11 +208,11 @@ int main(int argc, char* argv[])
     ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
     ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
     ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
-    print_sequence("LeftPads", LeftPads{});
-    print_sequence("LeftPads", LeftPads{});
-    print_sequence("RightPads", RightPads{});
-    print_sequence("ConvStrides", ConvStrides{});
-    print_sequence("ConvDilations", ConvDilations{});
+    print_array("LeftPads", LeftPads{});
+    print_array("LeftPads", LeftPads{});
+    print_array("RightPads", RightPads{});
+    print_array("ConvStrides", ConvStrides{});
+    print_array("ConvDilations", ConvDilations{});

     Tensor<float> in_nchw_device(make_HostTensorDescriptor(in_nchw_desc));
     Tensor<float> in_nchw_host(make_HostTensorDescriptor(in_nchw_desc));
@@ -248,7 +245,7 @@ int main(int argc, char* argv[])
     device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
 #elif 0
     device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
-#elif 0
+#elif 1
     device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
 #elif 1
     device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk
...
conv_bwd_data_driver.cpp
\ No newline at end of file
@@ -5,36 +5,167 @@
 #include <stdlib.h>
 #include <half.hpp>
 #include "config.hpp"
-#include "print_array.hpp"
-#include "print_sequence.hpp"
+#include "print.hpp"
 #include "device.hpp"
 #include "host_tensor_generator.hpp"
 #include "conv_common.hpp"
 #include "host_conv.hpp"
 #include "device_tensor.hpp"
-#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
-#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
+#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
+#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
+#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
+#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
+#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
+#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"

 int main(int argc, char* argv[])
 {
     using namespace ck;

 #if 0
-    // 1x1, 17x17
-    constexpr index_t N = 128;
-    constexpr index_t C = 1024;
-    constexpr index_t HI = 17;
-    constexpr index_t WI = 17;
-    constexpr index_t K = 256;
+    constexpr index_t N = 1;
+    constexpr index_t C = 16;
+    constexpr index_t HI = 1080;
+    constexpr index_t WI = 1920;
+    constexpr index_t K = 16;
+    constexpr index_t Y = 1;
+    constexpr index_t X = 1;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 0
+    constexpr index_t N = 1;
+    constexpr index_t C = 16;
+    constexpr index_t HI = 540;
+    constexpr index_t WI = 960;
+    constexpr index_t K = 16;
+    constexpr index_t Y = 1;
+    constexpr index_t X = 1;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 0
+    constexpr index_t N = 1;
+    constexpr index_t C = 16;
+    constexpr index_t HI = 270;
+    constexpr index_t WI = 480;
+    constexpr index_t K = 16;
     constexpr index_t Y = 1;
     constexpr index_t X = 1;

     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 1
+    constexpr index_t N = 1;
+    constexpr index_t C = 16;
+    constexpr index_t HI = 1080;
+    constexpr index_t WI = 1920;
+    constexpr index_t K = 16;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<1, 1>;
+    using RightPads = Sequence<1, 1>;
+#elif 0
+    constexpr index_t N = 1;
+    constexpr index_t C = 1;
+    constexpr index_t HI = 1024;
+    constexpr index_t WI = 2048;
+    constexpr index_t K = 4;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<1, 1>;
+    using RightPads = Sequence<1, 1>;
+#elif 0
+    constexpr index_t N = 1;
+    constexpr index_t C = 16;
+    constexpr index_t HI = 540;
+    constexpr index_t WI = 960;
+    constexpr index_t K = 16;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<1, 1>;
+    using RightPads = Sequence<1, 1>;
+#elif 0
+    constexpr index_t N = 1;
+    constexpr index_t C = 16;
+    constexpr index_t HI = 270;
+    constexpr index_t WI = 480;
+    constexpr index_t K = 16;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<1, 1>;
+    using RightPads = Sequence<1, 1>;
+#elif 0
+    // 3x3, 36x36, stride 2
+    constexpr index_t N = 128;
+    constexpr index_t C = 192;
+    constexpr index_t HI = 37;
+    constexpr index_t WI = 37;
+    constexpr index_t K = 384;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;
+
+    using ConvStrides = Sequence<2, 2>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 0
+    // 3x3, 35x35, stride 2
+    constexpr index_t N = 128;
+    constexpr index_t C = 192;
+    constexpr index_t HI = 35;
+    constexpr index_t WI = 35;
+    constexpr index_t K = 384;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;
+
+    using ConvStrides = Sequence<2, 2>;
+    using ConvDilations = Sequence<1, 1>;

     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
 #elif 0
+    // 3x3, 71x71
+    constexpr index_t N = 128;
+    constexpr index_t C = 192;
+    constexpr index_t HI = 71;
+    constexpr index_t WI = 71;
+    constexpr index_t K = 128;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;
+
+    using ConvStrides = Sequence<2, 2>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<1, 1>;
+    using RightPads = Sequence<1, 1>;
+#elif 1
     // 1x1, 8x8
     constexpr index_t N = 128;
     constexpr index_t C = 1536;
@@ -70,7 +201,7 @@ int main(int argc, char* argv[])
     constexpr index_t C = 96;
     constexpr index_t HI = 35;
     constexpr index_t WI = 35;
-    constexpr index_t K = 96;
+    constexpr index_t K = 128;
     constexpr index_t Y = 3;
     constexpr index_t X = 3;
@@ -94,7 +225,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<1, 1>;
     using RightPads = Sequence<1, 1>;
-#elif 0
+#elif 1
     // 7x1, 17x17
     constexpr index_t N = 128;
     constexpr index_t C = 128;
@@ -109,7 +240,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<3, 0>;
     using RightPads = Sequence<3, 0>;
-#elif 1
+#elif 0
     // 1x7, 17x17
     constexpr index_t N = 128;
     constexpr index_t C = 128;
@@ -141,12 +272,11 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<0, 0>;
 #elif 0
     // 3x3, 147x147
-    // v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
     constexpr index_t N = 128;
-    constexpr index_t C = 32;
+    constexpr index_t C = 128;
     constexpr index_t HI = 147;
     constexpr index_t WI = 147;
-    constexpr index_t K = 64;
+    constexpr index_t K = 128;
     constexpr index_t Y = 3;
     constexpr index_t X = 3;
@@ -157,7 +287,6 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<1, 1>;
 #elif 0
     // 3x3, 149x149
-    // v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
     constexpr index_t N = 128;
     constexpr index_t C = 32;
     constexpr index_t HI = 149;
@@ -201,7 +330,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 1
+#elif 0
     // 3x3, 35x35, stride 2
     constexpr index_t N = 128;
     constexpr index_t C = 288;
@@ -244,21 +373,6 @@ int main(int argc, char* argv[])
     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;
-    using LeftPads = Sequence<1, 0>;
-    using RightPads = Sequence<1, 0>;
-#elif 0
-    // 3x1, 8x8
-    constexpr index_t N = 128;
-    constexpr index_t C = 448;
-    constexpr index_t HI = 8;
-    constexpr index_t WI = 8;
-    constexpr index_t K = 512;
-    constexpr index_t Y = 3;
-    constexpr index_t X = 1;
-
-    using ConvStrides = Sequence<1, 1>;
-    using ConvDilations = Sequence<1, 1>;
     using LeftPads = Sequence<1, 0>;
     using RightPads = Sequence<1, 0>;
 #elif 0
@@ -278,7 +392,6 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<0, 0>;
 #elif 0
     // 7x1, 73x73
-    // v44@v100 xx.xx%, cudnn@v100 xx.xx%
     constexpr index_t N = 128;
     constexpr index_t C = 64;
     constexpr index_t HI = 73;
@@ -352,7 +465,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
     // 3x3, 28x28
     constexpr index_t N = 128;
     constexpr index_t C = 128;
@@ -382,7 +495,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<1, 1>;
     using RightPads = Sequence<1, 1>;
-#elif 1
+#elif 0
     // 1x1, 56x56, stride 2
     constexpr index_t N = 128;
     constexpr index_t C = 256;
@@ -442,7 +555,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
     // 1x1, 7x7
     constexpr index_t N = 128;
     constexpr index_t C = 512;
@@ -472,7 +585,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<1, 1>;
     using RightPads = Sequence<1, 1>;
-#elif 1
+#elif 0
     // 1x1, 56x56
     constexpr index_t N = 128;
     constexpr index_t C = 64;
@@ -487,7 +600,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 1
+#elif 0
     // 3x3, 56x56
     constexpr index_t N = 128;
     constexpr index_t C = 64;
@@ -512,17 +625,26 @@ int main(int argc, char* argv[])
     ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
     ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
     ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
-    print_sequence("LeftPads", LeftPads{});
-    print_sequence("RightPads", RightPads{});
-    print_sequence("ConvStrides", ConvStrides{});
-    print_sequence("ConvDilations", ConvDilations{});
+    print_array("LeftPads", to_multi_index(LeftPads{}));
+    print_array("RightPads", to_multi_index(RightPads{}));
+    print_array("ConvStrides", to_multi_index(ConvStrides{}));
+    print_array("ConvDilations", to_multi_index(ConvDilations{}));

-#if 1
+#if 0
     using in_data_t = float;
-    using out_data_t = float;
-#else
-    using in_data_t = half_float::half;
-    using out_data_t = half_float::half;
+    constexpr index_t in_vector_size = 1;
+    using acc_data_t = float;
+    using out_data_t = float;
+#elif 0
+    using in_data_t = float;
+    constexpr index_t in_vector_size = 1;
+    using acc_data_t = float;
+    using out_data_t = int8_t;
+#elif 1
+    using in_data_t = int8_t;
+    constexpr index_t in_vector_size = 4;
+    using acc_data_t = int32_t;
+    using out_data_t = int8_t;
 #endif

     Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc));
@@ -532,14 +654,15 @@ int main(int argc, char* argv[])
     std::size_t num_thread = std::thread::hardware_concurrency();

-    if(argc != 3)
+    if(argc != 4)
     {
-        printf("arg1: do_verification, arg2: nrepeat\n");
+        printf("arg1: do_verification, arg2: do_log, arg3: nrepeat\n");
         exit(1);
     }

     bool do_verification = atoi(argv[1]);
-    index_t nrepeat = atoi(argv[2]);
+    bool do_log = atoi(argv[2]);
+    index_t nrepeat = atoi(argv[3]);

     if(do_verification)
     {
@@ -548,9 +671,9 @@ int main(int argc, char* argv[])
         wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #elif 0
         in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-        wei_kcyx.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
+        wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
 #elif 0
-        in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
+        in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
         wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #elif 1
         in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
@@ -565,59 +688,112 @@ int main(int argc, char* argv[])
 #endif
     }

-#if 1
-    device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
+#if 0
+    device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
                                                                  in_nchw,
                                                                  wei_kcyx_desc,
                                                                  wei_kcyx,
                                                                  out_nkhw_desc,
                                                                  out_nkhw_device,
                                                                  ConvStrides{},
                                                                  ConvDilations{},
                                                                  LeftPads{},
                                                                  RightPads{},
                                                                  nrepeat);
+#elif 0
+    device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
+                                                                 in_nchw,
+                                                                 wei_kcyx_desc,
+                                                                 wei_kcyx,
+                                                                 out_nkhw_desc,
+                                                                 out_nkhw_device,
+                                                                 ConvStrides{},
+                                                                 ConvDilations{},
+                                                                 LeftPads{},
+                                                                 RightPads{},
+                                                                 nrepeat);
+#elif 0
+    device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(in_nchw_desc,
+                                                                 in_nchw,
+                                                                 wei_kcyx_desc,
+                                                                 wei_kcyx,
+                                                                 out_nkhw_desc,
+                                                                 out_nkhw_device,
+                                                                 ConvStrides{},
+                                                                 ConvDilations{},
+                                                                 LeftPads{},
+                                                                 RightPads{},
+                                                                 nrepeat);
+#elif 0
+    device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
+                                                                         in_vector_size,
+                                                                         acc_data_t,
+                                                                         out_data_t>
+        (in_nchw_desc,
+         in_nchw,
+         wei_kcyx_desc,
+         wei_kcyx,
+         out_nkhw_desc,
+         out_nkhw_device,
+         ConvStrides{},
+         ConvDilations{},
+         LeftPads{},
+         RightPads{},
+         nrepeat);
 #elif 1
-    device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
-                                                         in_nchw,
-                                                         wei_kcyx_desc,
-                                                         wei_kcyx,
-                                                         out_nkhw_desc,
-                                                         out_nkhw_device,
-                                                         ConvStrides{},
-                                                         ConvDilations{},
-                                                         LeftPads{},
-                                                         RightPads{},
-                                                         nrepeat);
+    device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
                                                                          in_vector_size,
                                                                          acc_data_t,
                                                                          out_data_t>
+        (in_nchw_desc,
+         in_nchw,
+         wei_kcyx_desc,
+         wei_kcyx,
+         out_nkhw_desc,
+         out_nkhw_device,
+         ConvStrides{},
+         ConvDilations{},
+         LeftPads{},
+         RightPads{},
+         nrepeat);
+#elif 1
+    device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t,
+                                                                         in_vector_size,
+                                                                         acc_data_t,
+                                                                         out_data_t>(
+        in_nchw_desc,
+        in_nchw,
+        wei_kcyx_desc,
+        wei_kcyx,
+        out_nkhw_desc,
+        out_nkhw_device,
+        ConvStrides{},
+        ConvDilations{},
+        LeftPads{},
+        RightPads{},
+        nrepeat);
 #endif

     if(do_verification)
     {
-#if 0
-        if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
-           ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
-        {
-            host_winograd_3x3_convolution(
-                in_nchw, wei_kcyx, out_nkhw_host, LeftPads{}, RightPads{});
-        }
-        else
-#endif
-        {
-            host_direct_convolution(in_nchw,
-                                    wei_kcyx,
-                                    out_nkhw_host,
-                                    ConvStrides{},
-                                    ConvDilations{},
-                                    LeftPads{},
-                                    RightPads{});
-        }
+        host_direct_convolution(in_nchw,
                                wei_kcyx,
                                out_nkhw_host,
                                ConvStrides{},
                                ConvDilations{},
                                LeftPads{},
                                RightPads{});

         check_error(out_nkhw_host, out_nkhw_device);

-#if 0
-        LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
-        LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
-        LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
-        LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
-#endif
+        if(do_log)
+        {
+            LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
+            LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
+            LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
+            LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
+        }
     }
 }