Commit aa5859e4 authored by Chao Liu

Merge remote-tracking branch 'origin/develop' into wavelet_model

parents 9bd6cc0e 5ee30459
......@@ -144,10 +144,18 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
return min(x, min(ys...));
}
template <typename T>
__host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
{
return min(max(x, lowerbound), upperbound);
}
// disallow implicit type casting
template <typename T>
__device__ T exp(T x);
// TODO: add f16 support using v_exp_f16
template <>
__device__ float exp<float>(float x)
{
......
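For reference, a minimal host-side sketch of what the clamp added above does: it composes min and max so the result is pinned to [lowerbound, upperbound]. The my_* helpers below are standalone stand-ins for ck::math::min/max (assumed to have the same semantics), not the library's own code.

template <typename T> constexpr T my_max(T a, T b) { return a > b ? a : b; }
template <typename T> constexpr T my_min(T a, T b) { return a < b ? a : b; }

// Same composition as ck::math::clamp above: min(max(x, lo), hi).
template <typename T>
constexpr T my_clamp(const T& x, const T& lowerbound, const T& upperbound)
{
    return my_min(my_max(x, lowerbound), upperbound);
}

int main()
{
    static_assert(my_clamp(5, 0, 3) == 3, "above range clamps to the upper bound");
    static_assert(my_clamp(-2, 0, 3) == 0, "below range clamps to the lower bound");
    static_assert(my_clamp(2, 0, 3) == 2, "in range passes through unchanged");
    return 0;
}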
......@@ -17,7 +17,7 @@ struct AccumulateWithNanIgnore
{
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
{
if(!isnan(currVal))
if(!ck::math::isnan(currVal))
{
ReduceOperation{}(accuVal, currVal);
}
......
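A minimal host-side sketch of the NaN-ignoring accumulation pattern shown above, assuming the reduce operation is a plain add; std::isnan stands in for ck::math::isnan and the function name is illustrative only.

#include <cmath>
#include <cstdio>
#include <vector>

// Skip NaN inputs so they do not poison the running accumulator
// (same idea as AccumulateWithNanIgnore above, with addition as the reduce op).
float sum_ignoring_nan(const std::vector<float>& values)
{
    float acc = 0.0f;
    for(float v : values)
    {
        if(!std::isnan(v))
        {
            acc += v;
        }
    }
    return acc;
}

int main()
{
    std::vector<float> values{1.0f, NAN, 2.5f};
    std::printf("%f\n", sum_ignoring_nan(values)); // prints 3.5
    return 0;
}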
......@@ -58,6 +58,33 @@ struct Add
}
};
struct SquaredAdd
{
template <class T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return type_convert<T>(0.0f);
};
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
return operation == InMemoryDataOperationEnum::AtomicAdd ||
operation == InMemoryDataOperationEnum::Set;
};
template <class T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
a = a + b * b;
}
};
struct Mul
{
template <typename T>
......
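The SquaredAdd accumulator added above keeps a running sum of squares (a = a + b * b), the building block for L2-norm and variance style reductions. A minimal host-side sketch under that assumption; SquaredAddSketch below is an illustrative stand-in, not the library type.

#include <cstdio>

// Same update rule as SquaredAdd::operator() above: a = a + b * b.
struct SquaredAddSketch
{
    template <class T>
    static constexpr T GetIdentityValue() { return T{0}; }

    template <class T>
    constexpr void operator()(T& a, T b) const { a = a + b * b; }
};

int main()
{
    float data[] = {1.0f, 2.0f, 3.0f};
    float acc    = SquaredAddSketch::GetIdentityValue<float>();
    SquaredAddSketch op{};
    for(float x : data)
        op(acc, x);               // accumulate x * x
    std::printf("%f\n", acc);     // prints 14.0 (1 + 4 + 9)
    return 0;
}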
......@@ -3,10 +3,10 @@
#pragma once
#include "integral_constant.hpp"
#include "type.hpp"
#include "functional.hpp"
#include "math.hpp"
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/functional.hpp"
#include "ck/utility/math.hpp"
namespace ck {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_SEQUENCE_HELPER_HPP
#define CK_SEQUENCE_HELPER_HPP
#pragma once
#include "tuple.hpp"
#include "ck/utility/tuple.hpp"
namespace ck {
......@@ -36,4 +35,3 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_STATIC_BUFFER_HPP
#define CK_STATIC_BUFFER_HPP
#pragma once
#include "statically_indexed_array.hpp"
......@@ -20,6 +19,22 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
__host__ __device__ constexpr StaticBuffer() : base{} {}
template <typename... Ys>
__host__ __device__ constexpr StaticBuffer& operator=(const Tuple<Ys...>& y)
{
static_assert(base::Size() == sizeof...(Ys), "wrong! size not the same");
StaticBuffer& x = *this;
static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y[i]; });
return x;
}
__host__ __device__ constexpr StaticBuffer& operator=(const T& y)
{
StaticBuffer& x = *this;
static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y; });
return x;
}
__host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
......@@ -40,10 +55,12 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
return base::operator()(i);
}
__host__ __device__ void Clear()
__host__ __device__ void Set(T x)
{
static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{0}; });
static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{x}; });
}
__host__ __device__ void Clear() { Set(T{0}); }
};
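The operator= overloads and Set/Clear added above fill a compile-time-sized buffer element by element with static_for, with Clear defined as Set(0). A minimal host-side sketch of the same idiom over a plain std::array (the library types and static_for are not used here):

#include <array>
#include <cstdio>

// Fill every element of a fixed-size buffer with one value, then clear it,
// mirroring StaticBuffer::Set / StaticBuffer::Clear above.
template <typename T, std::size_t N>
void set_all(std::array<T, N>& buf, T x)
{
    for(std::size_t i = 0; i < N; ++i)
        buf[i] = x;
}

template <typename T, std::size_t N>
void clear_all(std::array<T, N>& buf)
{
    set_all(buf, T{0}); // Clear() is just Set(0)
}

int main()
{
    std::array<float, 4> buf{};
    set_all(buf, 1.5f);
    clear_all(buf);
    std::printf("%f\n", buf[0]); // prints 0.0
    return 0;
}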
// static buffer for vector
......@@ -61,6 +78,7 @@ struct StaticBufferTupleOfVector
static constexpr auto s_per_v = Number<ScalarPerVector>{};
static constexpr auto num_of_v_ = Number<NumOfVector>{};
static constexpr auto s_per_buf = s_per_v * num_of_v_;
__host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}
......@@ -70,6 +88,8 @@ struct StaticBufferTupleOfVector
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
__host__ __device__ static constexpr index_t Size() { return s_per_buf; };
// Get S
// i is offset of S
template <index_t I>
......@@ -173,4 +193,3 @@ __host__ __device__ constexpr auto make_static_buffer(LongNumber<N>)
}
} // namespace ck
#endif
......@@ -34,7 +34,10 @@ __host__ __device__ constexpr auto to_multi_index(const T& x)
// is the alias of the latter. This is because the compiler cannot infer NSize if
// using MultiIndex<NSize>
// TODO: how to fix this?
template <typename... Ys, typename X>
template <
typename... Ys,
typename X,
enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
__host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
......@@ -43,7 +46,10 @@ __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
return y;
}
template <typename... Ys, typename X>
template <
typename... Ys,
typename X,
enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
__host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
......@@ -52,7 +58,10 @@ __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
return y;
}
template <typename... Xs, typename Y>
template <
typename... Xs,
typename Y,
enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
......@@ -63,7 +72,10 @@ __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
return r;
}
template <typename... Xs, typename Y>
template <
typename... Xs,
typename Y,
enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
......@@ -74,7 +86,10 @@ __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
return r;
}
template <typename... Xs, typename Y>
template <
typename... Xs,
typename Y,
enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
......@@ -85,9 +100,11 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
return r;
}
// MultiIndex = index_t * MultiIndex
template <typename... Xs>
__host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
// MultiIndex = scalar * MultiIndex
template <typename... Xs,
typename Y,
enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
{
constexpr index_t NSize = sizeof...(Xs);
......@@ -96,13 +113,40 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
return r;
}
// MultiIndex = MultiIndex * index_t
template <typename... Xs>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, index_t a)
// MultiIndex = MultiIndex * scalar
template <typename... Xs,
typename Y,
enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, Y a)
{
return a * x;
}
namespace mathext {
template <typename... Xs>
__host__ __device__ constexpr auto exp(const Tuple<Xs...>& x)
{
constexpr index_t NSize = sizeof...(Xs);
Tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r(i) = math::exp(x[i]); });
return r;
}
template <typename... Xs, typename Y>
__host__ __device__ constexpr auto max(const Tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
Tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r(i) = math::max(x[i], y[i]); });
return r;
}
} // namespace mathext
template <typename... Xs>
__host__ __device__ void print_multi_index(const Tuple<Xs...>& x)
{
......
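The operator overloads above are constrained with enable_if so that Tuple-with-Tuple arithmetic and scalar-with-Tuple arithmetic do not collide during overload resolution, and mathext adds elementwise exp/max. A minimal host-side sketch of the same overload-constraint idea on a toy wrapper type; Vec3 and its operators are assumption-level stand-ins, not the ck::Tuple API.

#include <array>
#include <cstdio>
#include <type_traits>

// Toy stand-in for a multi-index: a fixed-size vector of ints.
struct Vec3
{
    std::array<int, 3> v;
};

// element-wise Vec3 * Vec3: enabled only when the right-hand side is NOT a scalar
template <typename Y, std::enable_if_t<!std::is_arithmetic<Y>::value, bool> = false>
Vec3 operator*(const Vec3& x, const Y& y)
{
    Vec3 r{};
    for(int i = 0; i < 3; ++i)
        r.v[i] = x.v[i] * y.v[i];
    return r;
}

// scalar * Vec3: enabled only when the left-hand side IS a scalar
template <typename Y, std::enable_if_t<std::is_arithmetic<Y>::value, bool> = false>
Vec3 operator*(Y a, const Vec3& x)
{
    Vec3 r{};
    for(int i = 0; i < 3; ++i)
        r.v[i] = a * x.v[i];
    return r;
}

int main()
{
    Vec3 x{{1, 2, 3}};
    Vec3 y{{4, 5, 6}};
    Vec3 a = x * y;  // element-wise: {4, 10, 18}
    Vec3 b = 2 * x;  // scalar broadcast: {2, 4, 6}
    std::printf("%d %d\n", a.v[2], b.v[2]); // prints 18 6
    return 0;
}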
......@@ -18,14 +18,15 @@ __device__ void block_sync_lds()
__syncthreads();
#endif
}
__device__ void block_lds()
__device__ void s_nop()
{
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#if 1
asm volatile("\
s_waitcnt lgkmcnt(0) \
s_nop 0 \n \
" ::);
#else
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
#endif
}
......
......@@ -3,10 +3,10 @@
#pragma once
#include "integral_constant.hpp"
#include "sequence.hpp"
#include "type.hpp"
#include "enable_if.hpp"
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/enable_if.hpp"
namespace ck {
......@@ -21,6 +21,8 @@ struct TupleElementKey
template <typename Key, typename Data>
struct TupleElementKeyData
{
using DataType = Data;
#if 0 // workaround compiler complaint about implicitly-deleted default constructor
__host__ __device__ constexpr TupleElementKeyData() = default;
#else
......@@ -34,29 +36,40 @@ struct TupleElementKeyData
{
}
Data mData;
DataType mData;
};
// for read access of tuple element
template <typename Key, typename Data>
__host__ __device__ constexpr const Data&
get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
get_tuple_element_data_reference(const TupleElementKeyData<Key, Data>& x)
{
return static_cast<const Data&>(x.mData);
}
// for write access of tuple element
template <typename Key, typename Data>
__host__ __device__ constexpr Data& get_tuple_element_data(TupleElementKeyData<Key, Data>& x)
__host__ __device__ constexpr Data&
get_tuple_element_data_reference(TupleElementKeyData<Key, Data>& x)
{
return x.mData;
}
// TODO: not sure the use of reference is correct
template <typename Key, typename Data>
__host__ __device__ constexpr Data&& get_tuple_element_data(TupleElementKeyData<Key, Data>&& x)
__host__ __device__ constexpr Data&&
get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
{
return static_cast<Data&&>(x.mData);
}
// for inferring the type of a tuple element
template <typename Key, typename Data>
__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
{
return std::forward(x.mData);
}
template <typename Indices, typename... Xs>
struct TupleImpl;
......@@ -87,13 +100,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
template <index_t I>
__host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey<I>) const
{
return get_tuple_element_data<TupleElementKey<I>>(*this);
return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
}
template <index_t I>
__host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey<I>)
{
return get_tuple_element_data<TupleElementKey<I>>(*this);
return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
}
};
......@@ -185,7 +198,8 @@ struct Tuple<>
template <index_t I, typename TTuple>
struct tuple_element
{
using type = decltype(TTuple{}.At(Number<I>{}));
// type should keep the cv/ref qualifier of original tuple element
using type = decltype(detail::get_tuple_element_data<detail::TupleElementKey<I>>(TTuple{}));
};
template <index_t I, typename TTuple>
......
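The renamed get_tuple_element_data_reference accessors above return references for read/write access, while the by-value get_tuple_element_data overload exists only so that tuple_element can recover the stored element type through decltype in an unevaluated context. A minimal sketch of that key-based lookup idea on a two-element holder; Key, Slot, and Pair2 are illustrative names, not the library's TupleImpl.

#include <type_traits>
#include <utility>

// Key type tagging element I, mirroring TupleElementKey<I> above.
template <int I>
struct Key {};

// One storage slot per (key, data) pair, mirroring TupleElementKeyData.
template <typename K, typename D>
struct Slot
{
    D data;
};

// Reference accessor: the compiler selects the right base class from the key alone.
template <typename K, typename D>
const D& get_slot_reference(const Slot<K, D>& s) { return s.data; }

// A two-element "tuple" built by inheriting one Slot per element.
struct Pair2 : Slot<Key<0>, int>, Slot<Key<1>, float>
{
};

// Element type recovery via decltype on the accessor, mirroring tuple_element above.
template <int I, typename T>
using element_t = std::remove_cv_t<
    std::remove_reference_t<decltype(get_slot_reference<Key<I>>(std::declval<T>()))>>;

static_assert(std::is_same<element_t<0, Pair2>, int>::value, "element 0 is int");
static_assert(std::is_same<element_t<1, Pair2>, float>::value, "element 1 is float");

int main() { return 0; }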
......@@ -4,8 +4,8 @@
#pragma once
#include "ck/ck.hpp"
#include "integral_constant.hpp"
#include "enable_if.hpp"
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/enable_if.hpp"
namespace ck {
......
add_subdirectory(src/tensor_operation_instance/gpu)
add_subdirectory(src/host_tensor)
add_subdirectory(src/utility)
......@@ -7,7 +7,7 @@
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
......@@ -16,6 +16,7 @@ namespace host {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
......@@ -58,7 +59,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator
auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) {
const int K = arg.a_g_m_k_.mDesc.GetLengths()[2];
float v_acc = 0;
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
......@@ -68,10 +69,11 @@ struct ReferenceBatchedGemm : public device::BaseOperator
arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
v_acc += ck::type_convert<float>(v_a) * ck::type_convert<float>(v_b);
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
float v_c;
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
......@@ -81,8 +83,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator
make_ParallelTensorFunctor(f_gmk_gkn_gmn,
arg.c_g_m_n_.mDesc.GetLengths()[0],
arg.c_g_m_n_.mDesc.GetLengths()[1],
arg.c_g_m_n_.mDesc.GetLengths()[2])(
std::thread::hardware_concurrency());
arg.c_g_m_n_.mDesc.GetLengths()[2])();
return 0;
}
......
......@@ -6,8 +6,9 @@
#include <iostream>
#include <sstream>
#include "ck/library/utility/host_tensor.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
......@@ -91,7 +92,7 @@ struct ReferenceCGemm : public device::BaseOperator
v_c_real += v_a_real * v_b_real - v_a_imag * v_b_imag;
}
arg.c_m_n_real_(m, n) = v_c_real;
arg.c_m_n_real_(m, n) = ck::type_convert<CDataType>(v_c_real);
};
auto f_mk_kn_mn_imag = [&](auto m, auto n) {
......@@ -107,7 +108,7 @@ struct ReferenceCGemm : public device::BaseOperator
v_c_imag += v_a_real * v_b_imag + v_a_imag * v_b_real;
}
arg.c_m_n_imag_(m, n) = v_c_imag;
arg.c_m_n_imag_(m, n) = ck::type_convert<CDataType>(v_c_imag);
};
make_ParallelTensorFunctor(f_mk_kn_mn_real,
......
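The reference complex GEMM above keeps the real and imaginary parts in separate tensors and applies the usual product rule: c_real accumulates a_real*b_real - a_imag*b_imag and c_imag accumulates a_real*b_imag + a_imag*b_real. A minimal host-side sketch of that inner loop on raw row-major arrays; the function name and layout below are assumptions for illustration only.

#include <cstdio>
#include <vector>

// Complex matrix multiply with split real/imag storage, row-major MxK * KxN.
void cgemm_split(const std::vector<float>& a_re, const std::vector<float>& a_im,
                 const std::vector<float>& b_re, const std::vector<float>& b_im,
                 std::vector<float>& c_re, std::vector<float>& c_im,
                 int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float v_re = 0, v_im = 0;
            for(int k = 0; k < K; ++k)
            {
                const float ar = a_re[m * K + k], ai = a_im[m * K + k];
                const float br = b_re[k * N + n], bi = b_im[k * N + n];
                v_re += ar * br - ai * bi; // real part of (ar + i*ai) * (br + i*bi)
                v_im += ar * bi + ai * br; // imaginary part
            }
            c_re[m * N + n] = v_re;
            c_im[m * N + n] = v_im;
        }
}

int main()
{
    // 1x1x1 smoke test: (1 + 2i) * (3 + 4i) = -5 + 10i
    std::vector<float> ar{1}, ai{2}, br{3}, bi{4}, cr(1), ci(1);
    cgemm_split(ar, ai, br, bi, cr, ci, 1, 1, 1);
    std::printf("%f %f\n", cr[0], ci[0]); // prints -5.0 10.0
    return 0;
}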
......@@ -8,22 +8,24 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X]
template <typename InDataType,
// input descriptor in [G, N, C, Di, Hi, Wi] order
// weight descriptor in [G, K, C, Z, Y, X] order
// output descriptor in [G, N, K, Do, Ho, Wo] order
// physical layout is irrelevant
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2,
typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdData : public device::BaseOperator
{
// Argument
......@@ -73,36 +75,45 @@ struct ReferenceConvBwdData : public device::BaseOperator
float Run(const Argument& arg)
{
if constexpr(NumDimSpatial == 1)
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{
auto f_ncw = [&](auto n, auto c, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
std::size_t X = arg.weight_.mDesc.GetLengths()[2];
std::size_t Wo = arg.output_.mDesc.GetLengths()[2];
throw std::runtime_error("wrong! inconsistent dimension");
}
if constexpr(NDimSpatial == 1)
{
auto f_ncw = [&](auto g, auto n, auto c, auto wi) {
std::size_t K = arg.weight_.GetLengths()[1];
std::size_t X = arg.weight_.GetLengths()[3];
std::size_t Wo = arg.output_.GetLengths()[3];
AccDataType v_acc = 0;
float v_acc = 0;
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp = ck::type_convert<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]);
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]);
if(w_tmp % arg.conv_strides_[0] == 0)
{
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[0]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_out = 0;
AccDataType v_wei = 0;
float v_out = 0;
float v_wei = 0;
arg.out_element_op_(
v_out,
ck::type_convert<AccDataType>(arg.output_(n, k, wo)));
v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
arg.wei_element_op_(
v_wei, ck::type_convert<AccDataType>(arg.weight_(k, c, x)));
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
v_acc += v_out * v_wei;
}
......@@ -110,66 +121,72 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
arg.in_element_op_(v_acc, v_acc);
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc);
float v_in;
arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_in);
};
make_ParallelTensorFunctor(f_ncw,
arg.input_.mDesc.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2])(
arg.input_.GetLengths()[0],
arg.input_.GetLengths()[1],
arg.input_.GetLengths()[2],
arg.input_.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 2)
else if constexpr(NDimSpatial == 2)
{
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
std::size_t Y = arg.weight_.mDesc.GetLengths()[2];
std::size_t X = arg.weight_.mDesc.GetLengths()[3];
auto f_nchw = [&](auto g, auto n, auto c, auto hi, auto wi) {
std::size_t K = arg.weight_.GetLengths()[1];
std::size_t Y = arg.weight_.GetLengths()[3];
std::size_t X = arg.weight_.GetLengths()[4];
std::size_t Ho = arg.output_.mDesc.GetLengths()[2];
std::size_t Wo = arg.output_.mDesc.GetLengths()[3];
std::size_t Ho = arg.output_.GetLengths()[3];
std::size_t Wo = arg.output_.GetLengths()[4];
AccDataType v_acc = 0;
float v_acc = 0;
for(std::size_t y = 0; y < Y; ++y)
{
auto h_tmp = ck::type_convert<ck::long_index_t>(hi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]);
auto h_tmp = static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]);
if(h_tmp % arg.conv_strides_[0] == 0)
{
auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[0]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp =
ck::type_convert<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(x *
arg.conv_dilations_[1]);
static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]);
if(w_tmp % arg.conv_strides_[1] == 0)
{
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(
arg.conv_strides_[1]);
auto wo =
static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_out = 0;
AccDataType v_wei = 0;
float v_out = 0;
float v_wei = 0;
arg.out_element_op_(
v_out,
ck::type_convert<float>(
arg.output_(g, n, k, ho, wo)));
arg.out_element_op_(v_out,
ck::type_convert<AccDataType>(
arg.output_(n, k, ho, wo)));
arg.wei_element_op_(v_wei,
ck::type_convert<AccDataType>(
arg.weight_(k, c, y, x)));
arg.wei_element_op_(
v_wei,
ck::type_convert<float>(
arg.weight_(g, k, c, y, x)));
v_acc += v_out * v_wei;
}
......@@ -180,90 +197,91 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
AccDataType v_in;
float v_in;
arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
};
make_ParallelTensorFunctor(f_nchw,
arg.input_.mDesc.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2],
arg.input_.mDesc.GetLengths()[3])(
arg.input_.GetLengths()[0],
arg.input_.GetLengths()[1],
arg.input_.GetLengths()[2],
arg.input_.GetLengths()[3],
arg.input_.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 3)
else if constexpr(NDimSpatial == 3)
{
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
std::size_t Z = arg.weight_.mDesc.GetLengths()[2];
std::size_t Y = arg.weight_.mDesc.GetLengths()[3];
std::size_t X = arg.weight_.mDesc.GetLengths()[4];
auto f_ncdhw = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) {
std::size_t K = arg.weight_.GetLengths()[1];
std::size_t Z = arg.weight_.GetLengths()[3];
std::size_t Y = arg.weight_.GetLengths()[4];
std::size_t X = arg.weight_.GetLengths()[5];
std::size_t Do = arg.output_.mDesc.GetLengths()[2];
std::size_t Ho = arg.output_.mDesc.GetLengths()[3];
std::size_t Wo = arg.output_.mDesc.GetLengths()[4];
std::size_t Do = arg.output_.GetLengths()[3];
std::size_t Ho = arg.output_.GetLengths()[4];
std::size_t Wo = arg.output_.GetLengths()[5];
AccDataType v_acc = 0;
float v_acc = 0;
for(std::size_t z = 0; z < Z; ++z)
{
auto d_tmp = ck::type_convert<ck::long_index_t>(di) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]);
auto d_tmp = static_cast<ck::long_index_t>(di) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]);
if(d_tmp % arg.conv_strides_[0] == 0)
{
auto do_ = ck::type_convert<ck::long_index_t>(d_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
auto do_ = static_cast<ck::long_index_t>(d_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[0]);
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
{
for(std::size_t y = 0; y < Y; ++y)
{
auto h_tmp =
ck::type_convert<ck::long_index_t>(hi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(y *
arg.conv_dilations_[1]);
static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]);
if(h_tmp % arg.conv_strides_[1] == 0)
{
auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
ck::type_convert<ck::long_index_t>(
arg.conv_strides_[1]);
auto ho =
static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp =
ck::type_convert<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(
arg.in_left_pads_[2]) -
ck::type_convert<ck::long_index_t>(
x * arg.conv_dilations_[2]);
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(
arg.in_left_pads_[2]) -
static_cast<ck::long_index_t>(
x * arg.conv_dilations_[2]);
if(w_tmp % arg.conv_strides_[2] == 0)
{
auto wo =
ck::type_convert<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(
arg.conv_strides_[2]);
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(
arg.conv_strides_[2]);
if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo)
{
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_out = 0;
AccDataType v_wei = 0;
float v_out = 0;
float v_wei = 0;
arg.out_element_op_(
v_out,
ck::type_convert<AccDataType>(
arg.output_(
n, k, do_, ho, wo)));
ck::type_convert<float>(arg.output_(
g, n, k, do_, ho, wo)));
arg.wei_element_op_(
v_wei,
ck::type_convert<AccDataType>(
arg.weight_(k, c, z, y, x)));
ck::type_convert<float>(
arg.weight_(g, k, c, z, y, x)));
v_acc += v_out * v_wei;
}
......@@ -277,17 +295,20 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
AccDataType v_in;
float v_in;
arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
};
make_ParallelTensorFunctor(f_ncdhw,
arg.input_.mDesc.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2],
arg.input_.mDesc.GetLengths()[3],
arg.input_.mDesc.GetLengths()[4])(
arg.input_.GetLengths()[0],
arg.input_.GetLengths()[1],
arg.input_.GetLengths()[2],
arg.input_.GetLengths()[3],
arg.input_.GetLengths()[4],
arg.input_.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
......
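The backward-data loops above invert the forward index map: an input position wi only receives contributions from output positions wo satisfying wi + pad - x*dilation = wo*stride, so the code checks divisibility by the stride and the output bound before accumulating out * wei. A standalone sketch of just that 1D index test (illustrative names, not the library's code):

#include <cstdio>

// For a 1D convolution, decide whether input position wi is touched by filter tap x
// at some valid output position wo, and return that wo if so.
bool maps_to_output(long wi, long x, long pad, long stride, long dilation, long Wo, long& wo)
{
    const long w_tmp = wi + pad - x * dilation; // forward map solved for wo*stride
    if(w_tmp % stride != 0)
        return false;                           // falls between output positions
    wo = w_tmp / stride;
    return wo >= 0 && wo < Wo;                  // must stay inside the output extent
}

int main()
{
    long wo = -1;
    // stride 2, dilation 1, pad 1: input position 3, filter tap 0 -> output position 2
    if(maps_to_output(/*wi=*/3, /*x=*/0, /*pad=*/1, /*stride=*/2, /*dilation=*/1, /*Wo=*/4, wo))
        std::printf("wo = %ld\n", wo); // prints wo = 2
    return 0;
}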
......@@ -7,21 +7,25 @@
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X]
template <typename InDataType,
// input descriptor in [G, N, C, Di, Hi, Wi] order
// weight descriptor in [G, K, C, Z, Y, X] order
// output descriptor in [G, N, K, Do, Ho, Wo] order
// physical layout is irrelevant
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2,
typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdWeight : public device::BaseOperator
{
// Argument
......@@ -71,156 +75,162 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
float Run(const Argument& arg)
{
if constexpr(NumDimSpatial == 1)
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{
constexpr auto I0 = Number<0>{};
auto f_kcx = [&](auto k, auto c, auto x) {
throw std::runtime_error("wrong! inconsistent dimension");
}
if constexpr(NDimSpatial == 1)
{
auto f_kcx = [&](auto g, auto k, auto c, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo)
for(std::size_t wo = 0; wo < arg.output_.GetLengths()[3]; ++wo)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{
float v_out;
float v_in;
arg.out_element_op_(v_out,
ck::type_convert<float>(arg.output_(n, k, wo)));
arg.in_element_op_(v_in,
ck::type_convert<float>(arg.input_(n, c, wi)));
arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
v_acc += v_out * v_in;
}
}
}
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, x) = ck::type_convert<WeiDataType>(v_wei);
arg.weight_(g, k, c, x) = ck::type_convert<WeiDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kcx,
arg.weight_.mDesc.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2])(
arg.weight_.GetLengths()[0],
arg.weight_.GetLengths()[1],
arg.weight_.GetLengths()[2],
arg.weight_.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 2)
else if constexpr(NDimSpatial == 2)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[2]; ++ho)
for(std::size_t ho = 0; ho < arg.output_.GetLengths()[3]; ++ho)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[3]; ++wo)
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t wo = 0; wo < arg.output_.GetLengths()[4]; ++wo)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) +
ck::type_convert<ck::long_index_t>(x *
arg.conv_dilations_[I1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[2] &&
ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[3])
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{
float v_out;
float v_in;
arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(n, k, ho, wo)));
v_out,
ck::type_convert<float>(arg.output_(g, n, k, ho, wo)));
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
v_acc += v_out * v_in;
}
}
}
}
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
arg.weight_(g, k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kcyx,
arg.weight_.mDesc.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3])(
arg.weight_.GetLengths()[0],
arg.weight_.GetLengths()[1],
arg.weight_.GetLengths()[2],
arg.weight_.GetLengths()[3],
arg.weight_.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 3)
else if constexpr(NDimSpatial == 3)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) {
auto f_kczyx = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{
for(std::size_t do_ = 0; do_ < arg.output_.mDesc.GetLengths()[2]; ++do_)
for(std::size_t do_ = 0; do_ < arg.output_.GetLengths()[3]; ++do_)
{
auto di =
ck::type_convert<ck::long_index_t>(do_ * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[3]; ++ho)
auto di = static_cast<ck::long_index_t>(do_ * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t ho = 0; ho < arg.output_.GetLengths()[4]; ++ho)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I1]) +
ck::type_convert<ck::long_index_t>(y *
arg.conv_dilations_[I1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[4];
++wo)
static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(std::size_t wo = 0; wo < arg.output_.GetLengths()[5]; ++wo)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo *
arg.conv_strides_[I2]) +
ck::type_convert<ck::long_index_t>(
x * arg.conv_dilations_[I2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I2]);
static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] &&
arg.input_.GetLengths()[3] &&
hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] &&
arg.input_.GetLengths()[4] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4])
arg.input_.GetLengths()[5])
{
float v_out;
float v_in;
arg.out_element_op_(v_out,
ck::type_convert<float>(
arg.output_(n, k, do_, ho, wo)));
arg.in_element_op_(
v_in,
ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
arg.output_(g, n, k, do_, ho, wo)));
arg.in_element_op_(v_in,
ck::type_convert<float>(
arg.input_(g, n, c, di, hi, wi)));
v_acc += v_out * v_in;
}
......@@ -228,19 +238,21 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
}
}
}
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
arg.weight_(g, k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kczyx,
arg.weight_.mDesc.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3],
arg.weight_.mDesc.GetLengths()[4])(
arg.weight_.GetLengths()[0],
arg.weight_.GetLengths()[1],
arg.weight_.GetLengths()[2],
arg.weight_.GetLengths()[3],
arg.weight_.GetLengths()[4],
arg.weight_.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
......
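For the weight gradient above, each filter tap (g, k, c, x) accumulates out(g, n, k, wo) * in(g, n, c, wi) over every batch entry and output position, with wi = wo*stride + x*dilation - pad and out-of-range wi skipped. A small standalone sketch of that 1D reduction on raw arrays (single group and channel, illustrative names only):

#include <cstdio>
#include <vector>

// 1D weight-gradient reference for one filter tap x, layout [n][w] for input/output;
// wi = wo*stride + x*dilation - pad, as in the code above.
float wei_grad_1d(const std::vector<std::vector<float>>& out, // [N][Wo]
                  const std::vector<std::vector<float>>& in,  // [N][Wi]
                  int x, int stride, int dilation, int pad)
{
    float acc = 0.0f;
    for(std::size_t n = 0; n < out.size(); ++n)
        for(std::size_t wo = 0; wo < out[n].size(); ++wo)
        {
            const long wi = static_cast<long>(wo) * stride + x * dilation - pad;
            if(wi >= 0 && static_cast<std::size_t>(wi) < in[n].size())
                acc += out[n][wo] * in[n][wi]; // skip padded positions
        }
    return acc;
}

int main()
{
    // one batch entry: out = {1, 1}, in = {1, 2, 3}, stride 2, dilation 1, pad 0, tap x = 0
    // wo=0 -> wi=0 (in=1), wo=1 -> wi=2 (in=3): gradient = 1*1 + 1*3 = 4
    std::vector<std::vector<float>> out{{1, 1}}, in{{1, 2, 3}};
    std::printf("%f\n", wei_grad_1d(out, in, 0, 2, 1, 0)); // prints 4.0
    return 0;
}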
......@@ -8,7 +8,7 @@
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
......@@ -17,9 +17,10 @@ namespace host {
//
// @brief Reference implementation for forward convolution.
//
// @paragraph Supports both NCHW as well as NHWC formats (and their respective
// counterparts for weight and output) as long as tensor descriptor
// lengths is in NCHW.
// @paragraph
// Tensor descriptor in GNCHW/GKCYX/GNKHW dimensional order
// Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
// as long as the dimensions in the tensor descriptor are in GNCHW order
//
// @tparam InDataType Input tensor data type.
// @tparam WeiDataType Weights tensor data type.
......@@ -28,16 +29,20 @@ namespace host {
// operation.
// @tparam WeiElementwiseOperation Functor for weights tensor elementwise
// operation.
// @tparam NumDimSpatial Number of spatial dimensions.
// @tparam NDimSpatial Number of spatial dimensions.
//
template <typename InDataType,
// input descriptor in [G, N, C, Di, Hi, Wi] order
// weight descriptor in [G, K, C, Z, Y, X] order
// output descriptor in [G, N, K, Do, Ho, Wo] order
// physical layout is irrelevant
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2,
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvFwd : public device::BaseOperator
{
// Argument
......@@ -86,29 +91,37 @@ struct ReferenceConvFwd : public device::BaseOperator
float Run(const Argument& arg)
{
if constexpr(NumDimSpatial == 1)
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{
auto f_ncw = [&](auto n, auto k, auto wo) {
throw std::runtime_error("wrong! inconsistent dimension");
}
if constexpr(NDimSpatial == 1)
{
auto func = [&](auto g, auto n, auto k, auto wo) {
float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
for(std::size_t x = 0; x < arg.weight_.GetLengths()[3]; ++x)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{
float v_in;
float v_wei;
arg.in_element_op_(v_in,
ck::type_convert<float>(arg.input_(n, c, wi)));
arg.wei_element_op_(v_wei,
ck::type_convert<float>(arg.weight_(k, c, x)));
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
arg.wei_element_op_(
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
v_acc += v_in * v_wei;
}
......@@ -118,50 +131,53 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out;
arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out);
};
make_ParallelTensorFunctor(f_ncw,
arg.output_.mDesc.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2])(
make_ParallelTensorFunctor(func,
arg.output_.GetLengths()[0],
arg.output_.GetLengths()[1],
arg.output_.GetLengths()[2],
arg.output_.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 2)
else if constexpr(NDimSpatial == 2)
{
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
for(std::size_t y = 0; y < arg.weight_.GetLengths()[3]; ++y)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < arg.weight_.GetLengths()[4]; ++x)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[2] &&
ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[3])
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{
float v_in;
float v_wei;
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
arg.wei_element_op_(
v_wei, ck::type_convert<float>(arg.weight_(k, c, y, x)));
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, y, x)));
v_acc += v_in * v_wei;
}
}
......@@ -171,64 +187,65 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out;
arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
};
make_ParallelTensorFunctor(f_nchw,
arg.output_.mDesc.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2],
arg.output_.mDesc.GetLengths()[3])(
make_ParallelTensorFunctor(func,
arg.output_.GetLengths()[0],
arg.output_.GetLengths()[1],
arg.output_.GetLengths()[2],
arg.output_.GetLengths()[3],
arg.output_.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 3)
else if constexpr(NDimSpatial == 3)
{
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{
for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
for(std::size_t z = 0; z < arg.weight_.GetLengths()[3]; ++z)
{
auto di =
ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t y = 0; y < arg.weight_.GetLengths()[4]; ++y)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(std::size_t x = 0; x < arg.weight_.GetLengths()[5]; ++x)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo *
arg.conv_strides_[2]) +
ck::type_convert<ck::long_index_t>(x *
arg.conv_dilations_[2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]);
static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] &&
arg.input_.GetLengths()[3] &&
hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] &&
arg.input_.GetLengths()[4] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4])
arg.input_.GetLengths()[5])
{
float v_in;
float v_wei;
arg.in_element_op_(
v_in,
ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
arg.in_element_op_(v_in,
ck::type_convert<float>(
arg.input_(g, n, c, di, hi, wi)));
arg.wei_element_op_(
v_wei,
ck::type_convert<float>(arg.weight_(k, c, z, y, x)));
ck::type_convert<float>(arg.weight_(g, k, c, z, y, x)));
v_acc += v_in * v_wei;
}
}
......@@ -239,15 +256,17 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out;
arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
};
make_ParallelTensorFunctor(f_nchw,
arg.output_.mDesc.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2],
arg.output_.mDesc.GetLengths()[3],
arg.output_.mDesc.GetLengths()[4])(
make_ParallelTensorFunctor(func,
arg.output_.GetLengths()[0],
arg.output_.GetLengths()[1],
arg.output_.GetLengths()[2],
arg.output_.GetLengths()[3],
arg.output_.GetLengths()[4],
arg.output_.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
......@@ -267,7 +286,10 @@ struct ReferenceConvFwd : public device::BaseOperator
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
bool IsSupportedArgument(const device::BaseArgument*) override
{
return NDimSpatial >= 1 && NDimSpatial <= 3;
}
static auto MakeArgument(const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,
......
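The forward reference above computes out(g, n, k, wo...) = sum over c and filter taps of in(g, n, c, wi...) * wei(g, k, c, x...), with wi = wo*stride + x*dilation - pad and padded positions skipped. A small standalone 1D sketch of that loop nest for a single group/batch/channel (illustrative names, not the library's code):

#include <cstdio>
#include <vector>

// 1D forward convolution reference:
// out[wo] = sum_x in[wo*stride + x*dilation - pad] * wei[x], skipping padded positions.
std::vector<float> conv1d_fwd(const std::vector<float>& in, const std::vector<float>& wei,
                              int stride, int dilation, int pad, int Wo)
{
    std::vector<float> out(Wo, 0.0f);
    for(int wo = 0; wo < Wo; ++wo)
        for(std::size_t x = 0; x < wei.size(); ++x)
        {
            const long wi = static_cast<long>(wo) * stride
                          + static_cast<long>(x) * dilation - pad;
            if(wi >= 0 && static_cast<std::size_t>(wi) < in.size())
                out[wo] += in[wi] * wei[x];
        }
    return out;
}

int main()
{
    // in = {1, 2, 3}, wei = {1, 1}, stride 1, dilation 1, pad 0 -> out = {3, 5}
    auto out = conv1d_fwd({1, 2, 3}, {1, 1}, 1, 1, 0, 2);
    std::printf("%f %f\n", out[0], out[1]); // prints 3.0 5.0
    return 0;
}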
......@@ -7,7 +7,7 @@
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
......
......@@ -7,7 +7,7 @@
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
......
......@@ -7,7 +7,7 @@
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
......
......@@ -7,7 +7,7 @@
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
......