"docs/vscode:/vscode.git/clone" did not exist on "71bcaf99e2cb2c677bf3a9addb9e8039cbcab22a"
Commit aa5859e4 authored by Chao Liu

Merge remote-tracking branch 'origin/develop' into wavelet_model

parents 9bd6cc0e 5ee30459
...@@ -144,10 +144,18 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
     return min(x, min(ys...));
 }
 
+template <typename T>
+__host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
+{
+    return min(max(x, lowerbound), upperbound);
+}
+
 // disallow implicit type casting
 template <typename T>
 __device__ T exp(T x);
 
+// TODO: add f16 support using v_exp_f16
 template <>
 __device__ float exp<float>(float x)
 {
......
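Note: the hunk above adds a clamp helper composed from min and max. Below is a minimal standalone sketch of the same pattern (clamp_sketch is an illustrative name; it uses std::min/std::max instead of the ck::math overloads, so it is not the library implementation itself).

#include <algorithm>
#include <cassert>

// Standalone sketch of the clamp-from-min/max composition added above.
template <typename T>
constexpr T clamp_sketch(const T& x, const T& lowerbound, const T& upperbound)
{
    // clip from below first, then from above, exactly like min(max(x, lo), hi)
    return std::min(std::max(x, lowerbound), upperbound);
}

int main()
{
    assert(clamp_sketch(-3, 0, 7) == 0); // below the range -> lower bound
    assert(clamp_sketch(5, 0, 7) == 5);  // inside the range -> unchanged
    assert(clamp_sketch(9, 0, 7) == 7);  // above the range -> upper bound
    return 0;
}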
...@@ -17,7 +17,7 @@ struct AccumulateWithNanIgnore
 {
     __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
     {
-        if(!isnan(currVal))
+        if(!ck::math::isnan(currVal))
         {
             ReduceOperation{}(accuVal, currVal);
         }
......
...@@ -58,6 +58,33 @@ struct Add
     }
 };
 
+struct SquaredAdd
+{
+    template <class T>
+    __host__ __device__ static constexpr T GetIdentityValue()
+    {
+        return type_convert<T>(0.0f);
+    };
+
+    __host__ __device__ static constexpr bool
+    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
+    {
+        return operation == InMemoryDataOperationEnum::AtomicAdd ||
+               operation == InMemoryDataOperationEnum::Set;
+    };
+
+    template <class T>
+    __host__ __device__ inline constexpr void operator()(T& a, T b) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "The data type is not supported by the SquaredAdd accumulator!");
+
+        a = a + b * b;
+    }
+};
+
 struct Mul
 {
     template <typename T>
......
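Note: SquaredAdd accumulates a = a + b * b, so folding it over a sequence starting from its identity value 0 yields a sum of squares (the building block of L2-norm style reductions). A host-only sketch under that assumption, independent of the ck reduction framework (SquaredAddSketch is a made-up name):

#include <cassert>
#include <vector>

// Host-side sketch of what folding SquaredAdd over a range computes:
// acc = identity (0), then acc = acc + x * x for every element.
struct SquaredAddSketch
{
    template <class T>
    static constexpr T GetIdentityValue() { return T{0}; }

    template <class T>
    constexpr void operator()(T& a, T b) const { a = a + b * b; }
};

int main()
{
    std::vector<float> xs{1.0f, 2.0f, 3.0f};

    float acc = SquaredAddSketch::GetIdentityValue<float>();
    SquaredAddSketch op;

    for(float x : xs)
        op(acc, x); // accumulate squares

    assert(acc == 14.0f); // 1 + 4 + 9
    return 0;
}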
...@@ -3,10 +3,10 @@
 #pragma once
 
-#include "integral_constant.hpp"
-#include "type.hpp"
-#include "functional.hpp"
-#include "math.hpp"
+#include "ck/utility/integral_constant.hpp"
+#include "ck/utility/type.hpp"
+#include "ck/utility/functional.hpp"
+#include "ck/utility/math.hpp"
 
 namespace ck {
......
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 
-#ifndef CK_SEQUENCE_HELPER_HPP
-#define CK_SEQUENCE_HELPER_HPP
+#pragma once
 
-#include "tuple.hpp"
+#include "ck/utility/tuple.hpp"
 
 namespace ck {
...@@ -36,4 +35,3 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
 }
 
 } // namespace ck
-#endif
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 
-#ifndef CK_STATIC_BUFFER_HPP
-#define CK_STATIC_BUFFER_HPP
+#pragma once
 
 #include "statically_indexed_array.hpp"
...@@ -20,6 +19,22 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
     __host__ __device__ constexpr StaticBuffer() : base{} {}
 
+    template <typename... Ys>
+    __host__ __device__ constexpr StaticBuffer& operator=(const Tuple<Ys...>& y)
+    {
+        static_assert(base::Size() == sizeof...(Ys), "wrong! size not the same");
+
+        StaticBuffer& x = *this;
+
+        static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y[i]; });
+
+        return x;
+    }
+
+    __host__ __device__ constexpr StaticBuffer& operator=(const T& y)
+    {
+        StaticBuffer& x = *this;
+
+        static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y; });
+
+        return x;
+    }
+
     __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
 
     __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
...@@ -40,10 +55,12 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
         return base::operator()(i);
     }
 
-    __host__ __device__ void Clear()
+    __host__ __device__ void Set(T x)
     {
-        static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{0}; });
+        static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{x}; });
     }
+
+    __host__ __device__ void Clear() { Set(T{0}); }
 };
 
 // static buffer for vector
...@@ -61,6 +78,7 @@ struct StaticBufferTupleOfVector
     static constexpr auto s_per_v   = Number<ScalarPerVector>{};
     static constexpr auto num_of_v_ = Number<NumOfVector>{};
+    static constexpr auto s_per_buf = s_per_v * num_of_v_;
 
     __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}
...@@ -70,6 +88,8 @@ struct StaticBufferTupleOfVector
     __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
 
+    __host__ __device__ static constexpr index_t Size() { return s_per_buf; };
+
     // Get S
     // i is offset of S
     template <index_t I>
...@@ -173,4 +193,3 @@ __host__ __device__ constexpr auto make_static_buffer(LongNumber<N>)
 }
 
 } // namespace ck
-#endif
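Note: the StaticBuffer changes add broadcast assignment and re-express Clear() as Set(T{0}). A rough stand-in using std::array, since ck::StaticBuffer and static_for are not available outside the library (StaticBufferSketch is a hypothetical name):

#include <array>
#include <cassert>

// Sketch of the new StaticBuffer assignment semantics using std::array as a stand-in:
// assigning a scalar broadcasts it to every element, and Clear() is just Set(0).
template <typename T, int N>
struct StaticBufferSketch : std::array<T, N>
{
    void Set(T x)
    {
        for(int i = 0; i < N; ++i)
            (*this)[i] = x; // broadcast scalar, mirrors the static_for loop in the diff
    }

    void Clear() { Set(T{0}); }
};

int main()
{
    StaticBufferSketch<float, 4> buf{};
    buf.Set(2.5f);
    assert(buf[0] == 2.5f && buf[3] == 2.5f);
    buf.Clear();
    assert(buf[2] == 0.0f);
    return 0;
}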
...@@ -34,7 +34,10 @@ __host__ __device__ constexpr auto to_multi_index(const T& x)
 // is the alias of the latter. This is because compiler cannot infer the NSize if
 // using MultiIndex<NSize>
 // TODO: how to fix this?
-template <typename... Ys, typename X>
+template <
+    typename... Ys,
+    typename X,
+    enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
 __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
 {
     static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
...@@ -43,7 +46,10 @@ __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
     return y;
 }
 
-template <typename... Ys, typename X>
+template <
+    typename... Ys,
+    typename X,
+    enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
 __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
 {
     static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
...@@ -52,7 +58,10 @@ __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
     return y;
 }
 
-template <typename... Xs, typename Y>
+template <
+    typename... Xs,
+    typename Y,
+    enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
 {
     static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...@@ -63,7 +72,10 @@ __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
     return r;
 }
 
-template <typename... Xs, typename Y>
+template <
+    typename... Xs,
+    typename Y,
+    enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
 {
     static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...@@ -74,7 +86,10 @@ __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
     return r;
 }
 
-template <typename... Xs, typename Y>
+template <
+    typename... Xs,
+    typename Y,
+    enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
 {
     static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...@@ -85,9 +100,11 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
     return r;
 }
 
-// MultiIndex = index_t * MultiIndex
-template <typename... Xs>
-__host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
+// MultiIndex = scalar * MultiIndex
+template <typename... Xs,
+          typename Y,
+          enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
+__host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
 {
     constexpr index_t NSize = sizeof...(Xs);
...@@ -96,13 +113,40 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
     return r;
 }
 
-// MultiIndex = MultiIndex * index_t
-template <typename... Xs>
-__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, index_t a)
+// MultiIndex = MultiIndex * scalar
+template <typename... Xs,
+          typename Y,
+          enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
+__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, Y a)
 {
     return a * x;
 }
 
+namespace mathext {
+
+template <typename... Xs>
+__host__ __device__ constexpr auto exp(const Tuple<Xs...>& x)
+{
+    constexpr index_t NSize = sizeof...(Xs);
+
+    Tuple<Xs...> r;
+
+    static_for<0, NSize, 1>{}([&](auto i) { r(i) = math::exp(x[i]); });
+
+    return r;
+}
+
+template <typename... Xs, typename Y>
+__host__ __device__ constexpr auto max(const Tuple<Xs...>& x, const Y& y)
+{
+    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
+
+    constexpr index_t NSize = sizeof...(Xs);
+
+    Tuple<Xs...> r;
+
+    static_for<0, NSize, 1>{}([&](auto i) { r(i) = math::max(x[i], y[i]); });
+
+    return r;
+}
+
+} // namespace mathext
+
 template <typename... Xs>
 __host__ __device__ void print_multi_index(const Tuple<Xs...>& x)
 {
......
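Note: the new mathext helpers apply exp element-wise to a Tuple and take the element-wise max of two Tuples. A standalone sketch of the same element-wise pattern over std::array (exp_elementwise/max_elementwise are illustrative names, not ck API):

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>

// Per-element exp, like mathext::exp over a Tuple.
template <typename T, std::size_t N>
std::array<T, N> exp_elementwise(const std::array<T, N>& x)
{
    std::array<T, N> r{};
    for(std::size_t i = 0; i < N; ++i)
        r[i] = std::exp(x[i]);
    return r;
}

// Per-element max of two same-sized arrays, like mathext::max.
template <typename T, std::size_t N>
std::array<T, N> max_elementwise(const std::array<T, N>& x, const std::array<T, N>& y)
{
    std::array<T, N> r{};
    for(std::size_t i = 0; i < N; ++i)
        r[i] = std::max(x[i], y[i]);
    return r;
}

int main()
{
    std::array<float, 2> a{0.0f, 1.0f};
    std::array<float, 2> b{0.5f, 0.5f};
    assert(exp_elementwise(a)[0] == 1.0f); // exp(0) == 1
    assert(max_elementwise(a, b)[1] == 1.0f);
    return 0;
}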
...@@ -18,14 +18,15 @@ __device__ void block_sync_lds()
     __syncthreads();
 #endif
 }
 
-__device__ void block_lds()
+__device__ void s_nop()
 {
-#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
+#if 1
     asm volatile("\
-    s_waitcnt lgkmcnt(0) \
+    s_nop 0 \n \
     " ::);
 #else
-    __syncthreads();
+    __builtin_amdgcn_sched_barrier(0);
 #endif
 }
......
...@@ -3,10 +3,10 @@
 #pragma once
 
-#include "integral_constant.hpp"
-#include "sequence.hpp"
-#include "type.hpp"
-#include "enable_if.hpp"
+#include "ck/utility/integral_constant.hpp"
+#include "ck/utility/sequence.hpp"
+#include "ck/utility/type.hpp"
+#include "ck/utility/enable_if.hpp"
 
 namespace ck {
...@@ -21,6 +21,8 @@ struct TupleElementKey
 template <typename Key, typename Data>
 struct TupleElementKeyData
 {
+    using DataType = Data;
+
 #if 0 // workaround compiler complaint about implicitly-deleted default constructor
     __host__ __device__ constexpr TupleElementKeyData() = default;
 #else
...@@ -34,29 +36,40 @@ struct TupleElementKeyData
     {
     }
 
-    Data mData;
+    DataType mData;
 };
 
+// for read access of tuple element
 template <typename Key, typename Data>
 __host__ __device__ constexpr const Data&
-get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
+get_tuple_element_data_reference(const TupleElementKeyData<Key, Data>& x)
 {
     return static_cast<const Data&>(x.mData);
 }
 
+// for write access of tuple element
 template <typename Key, typename Data>
-__host__ __device__ constexpr Data& get_tuple_element_data(TupleElementKeyData<Key, Data>& x)
+__host__ __device__ constexpr Data&
+get_tuple_element_data_reference(TupleElementKeyData<Key, Data>& x)
 {
     return x.mData;
 }
 
 // TODO: not sure the use of reference is correct
 template <typename Key, typename Data>
-__host__ __device__ constexpr Data&& get_tuple_element_data(TupleElementKeyData<Key, Data>&& x)
+__host__ __device__ constexpr Data&&
+get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
 {
     return static_cast<Data&&>(x.mData);
 }
 
+// for inferring type of tuple element
+template <typename Key, typename Data>
+__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
+{
+    return std::forward(x.mData);
+}
+
 template <typename Indices, typename... Xs>
 struct TupleImpl;
...@@ -87,13 +100,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
     template <index_t I>
     __host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey<I>) const
     {
-        return get_tuple_element_data<TupleElementKey<I>>(*this);
+        return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
     }
 
     template <index_t I>
     __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey<I>)
     {
-        return get_tuple_element_data<TupleElementKey<I>>(*this);
+        return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
     }
 };
...@@ -185,7 +198,8 @@ struct Tuple<>
 template <index_t I, typename TTuple>
 struct tuple_element
 {
-    using type = decltype(TTuple{}.At(Number<I>{}));
+    // type should keep the cv/ref qualifier of original tuple element
+    using type = decltype(detail::get_tuple_element_data<detail::TupleElementKey<I>>(TTuple{}));
 };
 
 template <index_t I, typename TTuple>
......
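Note: the tuple_element change deduces the element type through detail::get_tuple_element_data so that the cv/ref qualifiers of the stored element are preserved, rather than taking whatever At() happens to return. A generic illustration of how the accessor used inside decltype changes the deduced type (Holder and the helper names are made up; this is not the ck implementation):

#include <type_traits>
#include <utility>

struct Holder
{
    float data;
    const float& At() const { return data; }
};

// Deduced through the reference-returning accessor: const float& (reference attached).
using ViaAccessor = decltype(std::declval<Holder>().At());

// Deduced from a helper that returns the element by value: plain float.
float element_by_value(const Holder& h) { return h.data; }
using ViaValue = decltype(element_by_value(std::declval<Holder>()));

static_assert(std::is_same<ViaAccessor, const float&>::value, "accessor yields a reference type");
static_assert(std::is_same<ViaValue, float>::value, "by-value helper yields the bare element type");

int main() { return 0; }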
...@@ -4,8 +4,8 @@
 #pragma once
 
 #include "ck/ck.hpp"
-#include "integral_constant.hpp"
-#include "enable_if.hpp"
+#include "ck/utility/integral_constant.hpp"
+#include "ck/utility/enable_if.hpp"
 
 namespace ck {
......
 add_subdirectory(src/tensor_operation_instance/gpu)
-add_subdirectory(src/host_tensor)
 add_subdirectory(src/utility)
...@@ -7,7 +7,7 @@
 #include <sstream>
 
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 
 namespace ck {
 namespace tensor_operation {
...@@ -16,6 +16,7 @@ namespace host {
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
+          typename AccDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>
...@@ -58,7 +59,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator
         auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) {
             const int K = arg.a_g_m_k_.mDesc.GetLengths()[2];
 
-            float v_acc = 0;
+            AccDataType v_acc = 0;
 
             for(int k = 0; k < K; ++k)
             {
...@@ -68,10 +69,11 @@ struct ReferenceBatchedGemm : public device::BaseOperator
                 arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
                 arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
 
-                v_acc += ck::type_convert<float>(v_a) * ck::type_convert<float>(v_b);
+                v_acc +=
+                    ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
             }
 
-            float v_c;
+            AccDataType v_c;
 
             arg.c_element_op_(v_c, v_acc);
...@@ -81,8 +83,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator
         make_ParallelTensorFunctor(f_gmk_gkn_gmn,
                                    arg.c_g_m_n_.mDesc.GetLengths()[0],
                                    arg.c_g_m_n_.mDesc.GetLengths()[1],
-                                   arg.c_g_m_n_.mDesc.GetLengths()[2])(
-            std::thread::hardware_concurrency());
+                                   arg.c_g_m_n_.mDesc.GetLengths()[2])();
 
         return 0;
     }
......
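Note: threading AccDataType through the reference batched GEMM lets the host reference accumulate in a different type than the input data. A minimal sketch of accumulation-type separation (dot_in_acc_type is an illustrative helper, not part of ck):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Accumulate a dot product in AccDataType rather than the element type,
// e.g. int32 accumulation for int8 inputs so partial sums do not overflow.
template <typename ADataType, typename BDataType, typename AccDataType>
AccDataType dot_in_acc_type(const std::vector<ADataType>& a, const std::vector<BDataType>& b)
{
    AccDataType acc = 0;
    for(std::size_t k = 0; k < a.size(); ++k)
        acc += static_cast<AccDataType>(a[k]) * static_cast<AccDataType>(b[k]);
    return acc;
}

int main()
{
    std::vector<int8_t> a{100, 100, 100};
    std::vector<int8_t> b{100, 100, 100};
    assert((dot_in_acc_type<int8_t, int8_t, int32_t>(a, b)) == 30000); // far beyond int8 range
    return 0;
}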
...@@ -6,8 +6,9 @@
 #include <iostream>
 #include <sstream>
 
+#include "ck/library/utility/host_tensor.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
 namespace ck {
 namespace tensor_operation {
...@@ -91,7 +92,7 @@ struct ReferenceCGemm : public device::BaseOperator
                 v_c_real += v_a_real * v_b_real - v_a_imag * v_b_imag;
             }
 
-            arg.c_m_n_real_(m, n) = v_c_real;
+            arg.c_m_n_real_(m, n) = ck::type_convert<CDataType>(v_c_real);
         };
 
         auto f_mk_kn_mn_imag = [&](auto m, auto n) {
...@@ -107,7 +108,7 @@ struct ReferenceCGemm : public device::BaseOperator
                 v_c_imag += v_a_real * v_b_imag + v_a_imag * v_b_real;
             }
 
-            arg.c_m_n_imag_(m, n) = v_c_imag;
+            arg.c_m_n_imag_(m, n) = ck::type_convert<CDataType>(v_c_imag);
         };
 
         make_ParallelTensorFunctor(f_mk_kn_mn_real,
......
...@@ -8,22 +8,24 @@
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 
 namespace ck {
 namespace tensor_operation {
 namespace host {
 
-// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X]
-template <typename InDataType,
+// input descriptor in [G, N, C, Di, Hi, Wi] order
+// weight descriptor in [G, K, C, Z, Y, X] order
+// output descriptor in [G, N, K, Do, Ho, Wo] order
+// physical layout is irrelevant
+template <ck::index_t NDimSpatial,
+          typename InDataType,
           typename WeiDataType,
           typename OutDataType,
-          typename AccDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
-          ck::index_t NumDimSpatial = 2,
-          typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
+          typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
 struct ReferenceConvBwdData : public device::BaseOperator
 {
     // Argument
...@@ -73,36 +75,45 @@ struct ReferenceConvBwdData : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
if constexpr(NumDimSpatial == 1) if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{ {
auto f_ncw = [&](auto n, auto c, auto wi) { throw std::runtime_error("wrong! inconsistent dimension");
std::size_t K = arg.weight_.mDesc.GetLengths()[0]; }
std::size_t X = arg.weight_.mDesc.GetLengths()[2];
std::size_t Wo = arg.output_.mDesc.GetLengths()[2]; if constexpr(NDimSpatial == 1)
{
auto f_ncw = [&](auto g, auto n, auto c, auto wi) {
std::size_t K = arg.weight_.GetLengths()[1];
std::size_t X = arg.weight_.GetLengths()[3];
std::size_t Wo = arg.output_.GetLengths()[3];
AccDataType v_acc = 0; float v_acc = 0;
for(std::size_t x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
{ {
auto w_tmp = ck::type_convert<ck::long_index_t>(wi) + auto w_tmp = static_cast<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]); static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]);
if(w_tmp % arg.conv_strides_[0] == 0) if(w_tmp % arg.conv_strides_[0] == 0)
{ {
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) / auto wo = static_cast<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]); static_cast<ck::long_index_t>(arg.conv_strides_[0]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo) if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{ {
for(std::size_t k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
AccDataType v_out = 0; float v_out = 0;
AccDataType v_wei = 0; float v_wei = 0;
arg.out_element_op_( arg.out_element_op_(
v_out, v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
ck::type_convert<AccDataType>(arg.output_(n, k, wo)));
arg.wei_element_op_( arg.wei_element_op_(
v_wei, ck::type_convert<AccDataType>(arg.weight_(k, c, x))); v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
v_acc += v_out * v_wei; v_acc += v_out * v_wei;
} }
...@@ -110,66 +121,72 @@ struct ReferenceConvBwdData : public device::BaseOperator
} }
} }
arg.in_element_op_(v_acc, v_acc); float v_in;
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc);
arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(f_ncw,
arg.input_.mDesc.GetLengths()[0], arg.input_.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1], arg.input_.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2])( arg.input_.GetLengths()[2],
arg.input_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { auto f_nchw = [&](auto g, auto n, auto c, auto hi, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0]; std::size_t K = arg.weight_.GetLengths()[1];
std::size_t Y = arg.weight_.mDesc.GetLengths()[2]; std::size_t Y = arg.weight_.GetLengths()[3];
std::size_t X = arg.weight_.mDesc.GetLengths()[3]; std::size_t X = arg.weight_.GetLengths()[4];
std::size_t Ho = arg.output_.mDesc.GetLengths()[2]; std::size_t Ho = arg.output_.GetLengths()[3];
std::size_t Wo = arg.output_.mDesc.GetLengths()[3]; std::size_t Wo = arg.output_.GetLengths()[4];
AccDataType v_acc = 0; float v_acc = 0;
for(std::size_t y = 0; y < Y; ++y) for(std::size_t y = 0; y < Y; ++y)
{ {
auto h_tmp = ck::type_convert<ck::long_index_t>(hi) + auto h_tmp = static_cast<ck::long_index_t>(hi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]); static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]);
if(h_tmp % arg.conv_strides_[0] == 0) if(h_tmp % arg.conv_strides_[0] == 0)
{ {
auto ho = ck::type_convert<ck::long_index_t>(h_tmp) / auto ho = static_cast<ck::long_index_t>(h_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]); static_cast<ck::long_index_t>(arg.conv_strides_[0]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho) if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{ {
for(std::size_t x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
{ {
auto w_tmp = auto w_tmp =
ck::type_convert<ck::long_index_t>(wi) + static_cast<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) - static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(x * static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]);
arg.conv_dilations_[1]);
if(w_tmp % arg.conv_strides_[1] == 0) if(w_tmp % arg.conv_strides_[1] == 0)
{ {
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) / auto wo =
ck::type_convert<ck::long_index_t>( static_cast<ck::long_index_t>(w_tmp) /
arg.conv_strides_[1]); static_cast<ck::long_index_t>(arg.conv_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo) if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{ {
for(std::size_t k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
AccDataType v_out = 0; float v_out = 0;
AccDataType v_wei = 0; float v_wei = 0;
arg.out_element_op_(
v_out,
ck::type_convert<float>(
arg.output_(g, n, k, ho, wo)));
arg.out_element_op_(v_out, arg.wei_element_op_(
ck::type_convert<AccDataType>( v_wei,
arg.output_(n, k, ho, wo))); ck::type_convert<float>(
arg.wei_element_op_(v_wei, arg.weight_(g, k, c, y, x)));
ck::type_convert<AccDataType>(
arg.weight_(k, c, y, x)));
v_acc += v_out * v_wei; v_acc += v_out * v_wei;
} }
...@@ -180,90 +197,91 @@ struct ReferenceConvBwdData : public device::BaseOperator
} }
} }
AccDataType v_in; float v_in;
arg.in_element_op_(v_in, v_acc); arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
arg.input_.mDesc.GetLengths()[0], arg.input_.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1], arg.input_.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2], arg.input_.GetLengths()[2],
arg.input_.mDesc.GetLengths()[3])( arg.input_.GetLengths()[3],
arg.input_.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) { auto f_ncdhw = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0]; std::size_t K = arg.weight_.GetLengths()[1];
std::size_t Z = arg.weight_.mDesc.GetLengths()[2]; std::size_t Z = arg.weight_.GetLengths()[3];
std::size_t Y = arg.weight_.mDesc.GetLengths()[3]; std::size_t Y = arg.weight_.GetLengths()[4];
std::size_t X = arg.weight_.mDesc.GetLengths()[4]; std::size_t X = arg.weight_.GetLengths()[5];
std::size_t Do = arg.output_.mDesc.GetLengths()[2]; std::size_t Do = arg.output_.GetLengths()[3];
std::size_t Ho = arg.output_.mDesc.GetLengths()[3]; std::size_t Ho = arg.output_.GetLengths()[4];
std::size_t Wo = arg.output_.mDesc.GetLengths()[4]; std::size_t Wo = arg.output_.GetLengths()[5];
AccDataType v_acc = 0; float v_acc = 0;
for(std::size_t z = 0; z < Z; ++z) for(std::size_t z = 0; z < Z; ++z)
{ {
auto d_tmp = ck::type_convert<ck::long_index_t>(di) + auto d_tmp = static_cast<ck::long_index_t>(di) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]); static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]);
if(d_tmp % arg.conv_strides_[0] == 0) if(d_tmp % arg.conv_strides_[0] == 0)
{ {
auto do_ = ck::type_convert<ck::long_index_t>(d_tmp) / auto do_ = static_cast<ck::long_index_t>(d_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]); static_cast<ck::long_index_t>(arg.conv_strides_[0]);
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do) if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
{ {
for(std::size_t y = 0; y < Y; ++y) for(std::size_t y = 0; y < Y; ++y)
{ {
auto h_tmp = auto h_tmp =
ck::type_convert<ck::long_index_t>(hi) + static_cast<ck::long_index_t>(hi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) - static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(y * static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]);
arg.conv_dilations_[1]);
if(h_tmp % arg.conv_strides_[1] == 0) if(h_tmp % arg.conv_strides_[1] == 0)
{ {
auto ho = ck::type_convert<ck::long_index_t>(h_tmp) / auto ho =
ck::type_convert<ck::long_index_t>( static_cast<ck::long_index_t>(h_tmp) /
arg.conv_strides_[1]); static_cast<ck::long_index_t>(arg.conv_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho) if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{ {
for(std::size_t x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
{ {
auto w_tmp = auto w_tmp = static_cast<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(wi) + static_cast<ck::long_index_t>(
ck::type_convert<ck::long_index_t>( arg.in_left_pads_[2]) -
arg.in_left_pads_[2]) - static_cast<ck::long_index_t>(
ck::type_convert<ck::long_index_t>( x * arg.conv_dilations_[2]);
x * arg.conv_dilations_[2]);
if(w_tmp % arg.conv_strides_[2] == 0) if(w_tmp % arg.conv_strides_[2] == 0)
{ {
auto wo = auto wo = static_cast<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(w_tmp) / static_cast<ck::long_index_t>(
ck::type_convert<ck::long_index_t>( arg.conv_strides_[2]);
arg.conv_strides_[2]);
if(wo >= 0 && if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo) ck::type_convert<std::size_t>(wo) < Wo)
{ {
for(std::size_t k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
AccDataType v_out = 0; float v_out = 0;
AccDataType v_wei = 0; float v_wei = 0;
arg.out_element_op_( arg.out_element_op_(
v_out, v_out,
ck::type_convert<AccDataType>( ck::type_convert<float>(arg.output_(
arg.output_( g, n, k, do_, ho, wo)));
n, k, do_, ho, wo)));
arg.wei_element_op_( arg.wei_element_op_(
v_wei, v_wei,
ck::type_convert<AccDataType>( ck::type_convert<float>(
arg.weight_(k, c, z, y, x))); arg.weight_(g, k, c, z, y, x)));
v_acc += v_out * v_wei; v_acc += v_out * v_wei;
} }
...@@ -277,17 +295,20 @@ struct ReferenceConvBwdData : public device::BaseOperator
} }
} }
AccDataType v_in; float v_in;
arg.in_element_op_(v_in, v_acc); arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_ncdhw, make_ParallelTensorFunctor(f_ncdhw,
arg.input_.mDesc.GetLengths()[0], arg.input_.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1], arg.input_.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2], arg.input_.GetLengths()[2],
arg.input_.mDesc.GetLengths()[3], arg.input_.GetLengths()[3],
arg.input_.mDesc.GetLengths()[4])( arg.input_.GetLengths()[4],
arg.input_.GetLengths()[5])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
......
...@@ -7,21 +7,25 @@
 #include <sstream>
 
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 
 namespace ck {
 namespace tensor_operation {
 namespace host {
 
-// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X]
-template <typename InDataType,
+// input descriptor in [G, N, C, Di, Hi, Wi] order
+// weight descriptor in [G, K, C, Z, Y, X] order
+// output descriptor in [G, N, K, Do, Ho, Wo] order
+// physical layout is irrelevant
+template <ck::index_t NDimSpatial,
+          typename InDataType,
           typename WeiDataType,
           typename OutDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
-          ck::index_t NumDimSpatial = 2,
-          typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
+          typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
 struct ReferenceConvBwdWeight : public device::BaseOperator
 {
     // Argument
...@@ -71,156 +75,162 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
if constexpr(NumDimSpatial == 1) if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{ {
constexpr auto I0 = Number<0>{}; throw std::runtime_error("wrong! inconsistent dimension");
auto f_kcx = [&](auto k, auto c, auto x) { }
if constexpr(NDimSpatial == 1)
{
auto f_kcx = [&](auto g, auto k, auto c, auto x) {
float v_acc = 0; float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{ {
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo) for(std::size_t wo = 0; wo < arg.output_.GetLengths()[3]; ++wo)
{ {
auto wi = auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I0]) + static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{ {
float v_out; float v_out;
float v_in; float v_in;
arg.out_element_op_(v_out, arg.out_element_op_(
ck::type_convert<float>(arg.output_(n, k, wo))); v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
arg.in_element_op_(v_in,
ck::type_convert<float>(arg.input_(n, c, wi))); arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
v_acc += v_out * v_in; v_acc += v_out * v_in;
} }
} }
} }
float v_wei; float v_wei;
arg.wei_element_op_(v_wei, v_acc); arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, x) = ck::type_convert<WeiDataType>(v_wei); arg.weight_(g, k, c, x) = ck::type_convert<WeiDataType>(v_wei);
}; };
make_ParallelTensorFunctor(f_kcx, make_ParallelTensorFunctor(f_kcx,
arg.weight_.mDesc.GetLengths()[0], arg.weight_.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1], arg.weight_.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2])( arg.weight_.GetLengths()[2],
arg.weight_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
constexpr auto I0 = Number<0>{}; auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) {
constexpr auto I1 = Number<1>{};
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
float v_acc = 0; float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{ {
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[2]; ++ho) for(std::size_t ho = 0; ho < arg.output_.GetLengths()[3]; ++ho)
{ {
auto hi = auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) + static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[3]; ++wo) for(std::size_t wo = 0; wo < arg.output_.GetLengths()[4]; ++wo)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(x * static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
arg.conv_dilations_[I1]) - static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
if(hi >= 0 && if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
arg.input_.mDesc.GetLengths()[2] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
arg.input_.mDesc.GetLengths()[3])
{ {
float v_out; float v_out;
float v_in; float v_in;
arg.out_element_op_( arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(n, k, ho, wo))); v_out,
ck::type_convert<float>(arg.output_(g, n, k, ho, wo)));
arg.in_element_op_( arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi))); v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
v_acc += v_out * v_in; v_acc += v_out * v_in;
} }
} }
} }
} }
float v_wei; float v_wei;
arg.wei_element_op_(v_wei, v_acc); arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, y, x) = ck::type_convert<WeiDataType>(v_wei); arg.weight_(g, k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
}; };
make_ParallelTensorFunctor(f_kcyx, make_ParallelTensorFunctor(f_kcyx,
arg.weight_.mDesc.GetLengths()[0], arg.weight_.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1], arg.weight_.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2], arg.weight_.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3])( arg.weight_.GetLengths()[3],
arg.weight_.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
constexpr auto I0 = Number<0>{}; auto f_kczyx = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) {
float v_acc = 0; float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{ {
for(std::size_t do_ = 0; do_ < arg.output_.mDesc.GetLengths()[2]; ++do_) for(std::size_t do_ = 0; do_ < arg.output_.GetLengths()[3]; ++do_)
{ {
auto di = auto di = static_cast<ck::long_index_t>(do_ * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(do_ * arg.conv_strides_[I0]) + static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[I0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]); for(std::size_t ho = 0; ho < arg.output_.GetLengths()[4]; ++ho)
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[3]; ++ho)
{ {
auto hi = auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I1]) + static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(y * static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
arg.conv_dilations_[I1]) - static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]); for(std::size_t wo = 0; wo < arg.output_.GetLengths()[5]; ++wo)
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[4];
++wo)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
arg.conv_strides_[I2]) + static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
ck::type_convert<ck::long_index_t>( static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
x * arg.conv_dilations_[I2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I2]);
if(di >= 0 && if(di >= 0 &&
ck::type_convert<std::size_t>(di) < ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] && arg.input_.GetLengths()[3] &&
hi >= 0 && hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] && arg.input_.GetLengths()[4] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4]) arg.input_.GetLengths()[5])
{ {
float v_out; float v_out;
float v_in; float v_in;
arg.out_element_op_(v_out, arg.out_element_op_(v_out,
ck::type_convert<float>( ck::type_convert<float>(
arg.output_(n, k, do_, ho, wo))); arg.output_(g, n, k, do_, ho, wo)));
arg.in_element_op_(
v_in, arg.in_element_op_(v_in,
ck::type_convert<float>(arg.input_(n, c, di, hi, wi))); ck::type_convert<float>(
arg.input_(g, n, c, di, hi, wi)));
v_acc += v_out * v_in; v_acc += v_out * v_in;
} }
...@@ -228,19 +238,21 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
} }
} }
} }
float v_wei; float v_wei;
arg.wei_element_op_(v_wei, v_acc); arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei); arg.weight_(g, k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
}; };
make_ParallelTensorFunctor(f_kczyx, make_ParallelTensorFunctor(f_kczyx,
arg.weight_.mDesc.GetLengths()[0], arg.weight_.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1], arg.weight_.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2], arg.weight_.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3], arg.weight_.GetLengths()[3],
arg.weight_.mDesc.GetLengths()[4])( arg.weight_.GetLengths()[4],
arg.weight_.GetLengths()[5])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
......
...@@ -8,7 +8,7 @@
 #include <sstream>
 
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 
 namespace ck {
 namespace tensor_operation {
...@@ -17,9 +17,10 @@ namespace host {
 //
 // @brief Reference implementation for forward convolution.
 //
-// @paragraph Supports both NCHW as well as NHWC formats (and their respective
-//            counterparts for weight and output) as long as tensor descriptor
-//            lengths is in NCHW.
+// @paragraph
+//            Tensor descriptor in GNCHW/GKCYX/GNKHW dimensional order
+//            Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
+//            as long as dimensions in tensor descriptor are in GNCHW order
 //
 // @tparam InDataType Input tensor data type.
 // @tparam WeiDataType Weights tensor data type.
...@@ -28,16 +29,20 @@ namespace host {
 //            operation.
 // @tparam WeiElementwiseOperation Functor for weights tensor elementwise
 //            operation.
-// @tparam NumDimSpatial Number of spatial dimensions.
+// @tparam NDimSpatial Number of spatial dimensions.
 //
-template <typename InDataType,
+// input descriptor in [G, N, C, Di, Hi, Wi] order
+// weight descriptor in [G, K, C, Z, Y, X] order
+// output descriptor in [G, N, K, Do, Ho, Wo] order
+// physical layout is irrelevant
+template <ck::index_t NDimSpatial,
+          typename InDataType,
           typename WeiDataType,
           typename OutDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
-          ck::index_t NumDimSpatial = 2,
-          typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
+          typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
 struct ReferenceConvFwd : public device::BaseOperator
 {
     // Argument
...@@ -86,29 +91,37 @@ struct ReferenceConvFwd : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
if constexpr(NumDimSpatial == 1) if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{ {
auto f_ncw = [&](auto n, auto k, auto wo) { throw std::runtime_error("wrong! inconsistent dimension");
}
if constexpr(NDimSpatial == 1)
{
auto func = [&](auto g, auto n, auto k, auto wo) {
float v_acc = 0; float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{ {
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x) for(std::size_t x = 0; x < arg.weight_.GetLengths()[3]; ++x)
{ {
auto wi = auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) + static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
arg.in_element_op_(v_in, arg.in_element_op_(
ck::type_convert<float>(arg.input_(n, c, wi))); v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
arg.wei_element_op_(v_wei,
ck::type_convert<float>(arg.weight_(k, c, x))); arg.wei_element_op_(
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
v_acc += v_in * v_wei; v_acc += v_in * v_wei;
} }
...@@ -118,50 +131,53 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out; float v_out;
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(func,
arg.output_.mDesc.GetLengths()[0], arg.output_.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1], arg.output_.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2])( arg.output_.GetLengths()[2],
arg.output_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{ {
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y) for(std::size_t y = 0; y < arg.weight_.GetLengths()[3]; ++y)
{ {
auto hi = auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) + static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x) for(std::size_t x = 0; x < arg.weight_.GetLengths()[4]; ++x)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 && if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
arg.input_.mDesc.GetLengths()[2] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
arg.input_.mDesc.GetLengths()[3])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
arg.in_element_op_( arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi))); v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
arg.wei_element_op_( arg.wei_element_op_(
v_wei, ck::type_convert<float>(arg.weight_(k, c, y, x))); v_wei, ck::type_convert<float>(arg.weight_(g, k, c, y, x)));
v_acc += v_in * v_wei; v_acc += v_in * v_wei;
} }
} }
...@@ -171,64 +187,65 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out; float v_out;
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(func,
arg.output_.mDesc.GetLengths()[0], arg.output_.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1], arg.output_.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2], arg.output_.GetLengths()[2],
arg.output_.mDesc.GetLengths()[3])( arg.output_.GetLengths()[3],
arg.output_.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) { auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{ {
for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z) for(std::size_t z = 0; z < arg.weight_.GetLengths()[3]; ++z)
{ {
auto di = auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) + static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]); for(std::size_t y = 0; y < arg.weight_.GetLengths()[4]; ++y)
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
{ {
auto hi = auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) + static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x) for(std::size_t x = 0; x < arg.weight_.GetLengths()[5]; ++x)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
arg.conv_strides_[2]) + static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
ck::type_convert<ck::long_index_t>(x * static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
arg.conv_dilations_[2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]);
if(di >= 0 && if(di >= 0 &&
ck::type_convert<std::size_t>(di) < ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] && arg.input_.GetLengths()[3] &&
hi >= 0 && hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] && arg.input_.GetLengths()[4] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4]) arg.input_.GetLengths()[5])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
arg.in_element_op_( arg.in_element_op_(v_in,
v_in, ck::type_convert<float>(
ck::type_convert<float>(arg.input_(n, c, di, hi, wi))); arg.input_(g, n, c, di, hi, wi)));
arg.wei_element_op_( arg.wei_element_op_(
v_wei, v_wei,
ck::type_convert<float>(arg.weight_(k, c, z, y, x))); ck::type_convert<float>(arg.weight_(g, k, c, z, y, x)));
v_acc += v_in * v_wei; v_acc += v_in * v_wei;
} }
} }
...@@ -239,15 +256,17 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out; float v_out;
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(func,
arg.output_.mDesc.GetLengths()[0], arg.output_.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1], arg.output_.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2], arg.output_.GetLengths()[2],
arg.output_.mDesc.GetLengths()[3], arg.output_.GetLengths()[3],
arg.output_.mDesc.GetLengths()[4])( arg.output_.GetLengths()[4],
arg.output_.GetLengths()[5])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -267,7 +286,10 @@ struct ReferenceConvFwd : public device::BaseOperator
         return true;
     }
 
-    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
+    bool IsSupportedArgument(const device::BaseArgument*) override
+    {
+        return NDimSpatial >= 1 && NDimSpatial <= 3;
+    }
 
     static auto MakeArgument(const Tensor<InDataType>& input,
                              const Tensor<WeiDataType>& weight,
......
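Note: with the G dimension and the NDimSpatial template parameter, the reference convolutions above index tensors in [G, N, C, ...] / [G, K, C, ...] / [G, N, K, ...] descriptor order. A minimal 1D convolution sketch of that loop structure with ad hoc flat indexing (not ck's Tensor class); sizes and names are made up:

#include <cassert>
#include <vector>

int main()
{
    // [G=1, N=1, C=1, Wi] input, [G=1, K=1, C=1, X] weight, [G=1, N=1, K=1, Wo] output
    const int G = 1, N = 1, C = 1, K = 1, Wi = 4, X = 2;
    const int stride = 1, dilation = 1, left_pad = 0;
    const int Wo = (Wi + left_pad - dilation * (X - 1) - 1) / stride + 1; // = 3

    std::vector<float> in{1, 2, 3, 4};
    std::vector<float> wei{1, 1};
    std::vector<float> out(G * N * K * Wo, 0.0f);

    for(int g = 0; g < G; ++g)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                for(int wo = 0; wo < Wo; ++wo)
                {
                    float acc = 0;
                    for(int c = 0; c < C; ++c)
                        for(int x = 0; x < X; ++x)
                        {
                            // same index arithmetic as the reference kernels above
                            const int wi = wo * stride + x * dilation - left_pad;
                            if(wi >= 0 && wi < Wi)
                                acc += in[wi] * wei[x]; // single group/channel, flat indexing
                        }
                    out[wo] = acc;
                }

    assert(out[0] == 3 && out[1] == 5 && out[2] == 7); // sliding sums of adjacent inputs
    return 0;
}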
...@@ -7,7 +7,7 @@
 #include <sstream>
 
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 
 namespace ck {
 namespace tensor_operation {
......
...@@ -7,7 +7,7 @@
 #include <sstream>
 
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 
 namespace ck {
 namespace tensor_operation {
......
...@@ -7,7 +7,7 @@
 #include <sstream>
 
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 
 namespace ck {
 namespace tensor_operation {
......
...@@ -7,7 +7,7 @@
 #include <sstream>
 
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 
 namespace ck {
 namespace tensor_operation {
......