Commit 5e88414a authored by Chao Liu

use statically indexed array for all existing kernels

parent a578ff93
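
The recurring pattern in this commit is replacing runtime subscripts (block_work_id[0], gemm_sizes.At(4)) with statically indexed access (block_work_id[Number<0>{}], gemm_sizes[Number<4>{}]), alongside renaming the array helpers to container helpers. A minimal sketch of why the tagged subscript matters; Number and StaticallyIndexedArray below are hypothetical stand-ins modeled on the identifiers in this diff, not CK's actual definitions:

#include <cstddef>
#include <tuple>

// Hypothetical stand-ins for ck::Number and ck::StaticallyIndexedArray,
// reduced to the one property the commit relies on: the element index is
// part of the subscript's type, so access resolves at compile time.
template <std::size_t I>
struct Number
{
    static constexpr std::size_t value = I;
};

template <typename... Ts>
struct StaticallyIndexedArray
{
    std::tuple<Ts...> data;

    template <std::size_t I>
    constexpr const auto& operator[](Number<I>) const
    {
        return std::get<I>(data); // index fixed at compile time
    }
};

int main()
{
    constexpr StaticallyIndexedArray<int, long> a{{3, 4L}};
    // a[1] would require a runtime-indexable, homogeneous container;
    // a[Number<1>{}] selects the element through the type system instead.
    static_assert(a[Number<1>{}] == 4L, "resolved at compile time");
}

Because the index is part of the subscript's type, access can be constant-folded and works even when element types differ; presumably this is why the kernels below write GetLengths()[I0] rather than GetLengths()[0].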
......@@ -107,8 +107,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
- const index_t e_block_data_on_global = block_work_id[0] * EPerBlock;
- const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
+ const index_t e_block_data_on_global = block_work_id[Number<0>{}] * EPerBlock;
+ const index_t b_block_data_on_global = block_work_id[Number<1>{}] * BPerBlock;
// output tensor
// global tensor in global memory, src of blockwise copy
......@@ -151,7 +151,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
- {0, b_block_data_on_global, 0}, {0, 0, 0});
+ make_multi_index(0, b_block_data_on_global, 0), make_multi_index(0, 0, 0));
// weight tensor
// global tensor in global memory, src of blockwise copy
......@@ -191,7 +191,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
- {0, e_block_data_on_global, 0}, {0, 0, 0});
+ make_multi_index(0, e_block_data_on_global, 0), make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -434,13 +434,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
InThreadCopyDstDataPerWrite_B,
AddressSpace::Vgpr,
AddressSpace::Global,
- in_memory_op>({0, 0, 0, 0, 0, 0},
- {e_thread_data_on_global / E1,
- e_thread_data_on_global % E1,
- 0,
- b_thread_data_on_global / B1,
- b_thread_data_on_global % B1,
- 0})
+ in_memory_op>(make_multi_index(0, 0, 0, 0, 0, 0),
+ make_multi_index(e_thread_data_on_global / E1,
+ e_thread_data_on_global % E1,
+ 0,
+ b_thread_data_on_global / B1,
+ b_thread_data_on_global % B1,
+ 0))
.Run(p_in_thread, p_in_global);
}
}
......
......@@ -125,7 +125,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk
index_t GemmK1 = XDotSlice;
index_t GemmK2 = K;
- return Array<index_t, 5>{GemmM, GemmN, GemmK0, GemmK1, GemmK2};
+ return make_multi_index(GemmM, GemmN, GemmK0, GemmK1, GemmK2);
}
__host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id)
......
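
The hunks above replace brace-initialized Array<index_t, N> values, both the blockwise-copy origins like {0, b_block_data_on_global, 0} and the GetGemmSize return, with make_multi_index(...). A sketch of the shape such a factory plausibly has (the real CK signature and return type may differ): the rank is fixed by the argument count, so a call site cannot silently produce a differently sized index.

#include <array>
#include <cstdint>

using index_t = std::int32_t;

// Sketch of a make_multi_index-style factory: N scalar components in,
// a rank-N index out, with N deduced from the call site.
template <typename... Is>
constexpr auto make_multi_index(Is... is)
{
    return std::array<index_t, sizeof...(Is)>{static_cast<index_t>(is)...};
}

int main()
{
    // Mirrors the copy-origin arguments above: {0, b, 0} becomes
    // make_multi_index(0, b, 0), whose rank is explicit in the call.
    constexpr auto origin = make_multi_index(0, 128, 0);
    static_assert(origin.size() == 3, "rank deduced from arguments");
    static_assert(origin[1] == 128, "second component preserved");
}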
......@@ -226,7 +226,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
- {0, k_block_data_on_global}, {0, 0});
+ make_multi_index(0, k_block_data_on_global), make_multi_index(0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......
......@@ -242,8 +242,8 @@ struct DynamicTransformedTensorDescriptor
static_for<0, NTransform, 1>{}([&](auto itran) constexpr {
const auto tran = transforms_.At(itran);
- const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));
- auto idx_low_part = pick_array_element(idx_low, LowDimensionIds{}.At(itran));
+ const auto idx_up_part = pick_container_element(idx_up, UpDimensionIds{}.At(itran));
+ auto idx_low_part = pick_container_element(idx_low, LowDimensionIds{}.At(itran));
tran.CalculateLowerIndex(idx_low_part, idx_up_part);
});
......@@ -259,14 +259,16 @@ struct DynamicTransformedTensorDescriptor
const auto tran = transforms_.At(itran);
const auto idx_up_diff_part =
- pick_array_element(idx_up_diff, UpDimensionIds{}.At(itran));
+ pick_container_element(idx_up_diff, UpDimensionIds{}.At(itran));
- const auto idx_up_old_part = pick_array_element(idx_up_old, UpDimensionIds{}.At(itran));
+ const auto idx_up_old_part =
+ pick_container_element(idx_up_old, UpDimensionIds{}.At(itran));
const auto idx_low_old_part =
- pick_array_element(idx_low_old, LowDimensionIds{}.At(itran));
+ pick_container_element(idx_low_old, LowDimensionIds{}.At(itran));
- auto idx_low_diff_part = pick_array_element(idx_low_diff, LowDimensionIds{}.At(itran));
+ auto idx_low_diff_part =
+ pick_container_element(idx_low_diff, LowDimensionIds{}.At(itran));
tran.CalculateLowerIndexDiff(
idx_low_diff_part, idx_up_diff_part, idx_low_old_part, idx_up_old_part);
......@@ -325,7 +327,7 @@ struct DynamicTransformedTensorDescriptor
if constexpr(!is_valid_up_always_mapped_to_valid_low)
{
const auto up_dims_part = UpDimensionIds{}.At(itran);
- const auto idx_up_part = pick_array_element(idx_up, up_dims_part);
+ const auto idx_up_part = pick_container_element(idx_up, up_dims_part);
flag = flag && tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up_part);
}
......
......@@ -140,7 +140,7 @@ struct DynamicTensorDescriptor_v2
MultiIndex<ndim_hidden> idx_hidden;
// initialize visible index
- auto idx_hidden_pick_visible = pick_array_element(idx_hidden, visible_dim_ids);
+ auto idx_hidden_pick_visible = pick_container_element(idx_hidden, visible_dim_ids);
idx_hidden_pick_visible = idx;
// calculate hidden index
......@@ -149,8 +149,8 @@ struct DynamicTensorDescriptor_v2
constexpr auto dims_low = GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = GetUpperDimensionIdss().At(itran);
- const auto idx_up = pick_array_element(idx_hidden, dims_up);
- auto idx_low = pick_array_element(idx_hidden, dims_low);
+ const auto idx_up = pick_container_element(idx_hidden, dims_up);
+ auto idx_low = pick_container_element(idx_hidden, dims_low);
tran.CalculateLowerIndex(idx_low, idx_up);
});
......@@ -193,7 +193,7 @@ struct DynamicTensorDescriptor_v2
constexpr auto up_dim_ids = UpperDimensionIdss{}.At(itran);
// lengths_hidden_pick_up contains a reference to lengths_hidden
- auto hidden_lengths_pick_up = pick_array_element(hidden_lengths, up_dim_ids);
+ auto hidden_lengths_pick_up = pick_container_element(hidden_lengths, up_dim_ids);
hidden_lengths_pick_up = tran.GetUpperLengths();
});
......@@ -207,7 +207,7 @@ struct DynamicTensorDescriptor_v2
// variable lengths_) to save space on stack?
const HiddenIndex hidden_lengths_;
// visible_lengths_ contains a reference to hidden_lengths_
- const ArrayElementPicker<const HiddenIndex, VisibleDimensionIds> visible_lengths_;
+ const ContainerElementPicker<const HiddenIndex, VisibleDimensionIds> visible_lengths_;
#if 0
// friend class
......@@ -283,7 +283,7 @@ struct DynamicTensorCoordinate_v2
// private member variables
HiddenIndex idx_hidden_;
// idx_visible_ contains a reference to idx_hidden_
- ArrayElementPicker<HiddenIndex, VisibleDimensionIds> idx_visible_;
+ ContainerElementPicker<HiddenIndex, VisibleDimensionIds> idx_visible_;
#if 0
// friend functions for making and updating tensor coordinate
......@@ -441,7 +441,7 @@ make_dynamic_tensor_coordinate_v2(const TensorDesc& tensor_desc, const VisibleIn
MultiIndex<ndim_hidden> idx_hidden;
// initialize visible index
- auto idx_hidden_pick_visible = pick_array_element(idx_hidden, visible_dim_ids);
+ auto idx_hidden_pick_visible = pick_container_element(idx_hidden, visible_dim_ids);
idx_hidden_pick_visible = idx_visible;
// calculate hidden index
......@@ -451,8 +451,8 @@ make_dynamic_tensor_coordinate_v2(const TensorDesc& tensor_desc, const VisibleIn
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
- const auto idx_up = pick_array_element(idx_hidden, dims_up);
- auto idx_low = pick_array_element(idx_hidden, dims_low);
+ const auto idx_up = pick_container_element(idx_hidden, dims_up);
+ auto idx_low = pick_container_element(idx_hidden, dims_low);
tran.CalculateLowerIndex(idx_low, idx_up);
});
......@@ -477,7 +477,7 @@ make_dynamic_tensor_coordinate_step_v2(const TensorDesc&, const VisibleIndex& id
Array<bool, ndim_hidden> non_zero_diff{false};
- auto non_zero_diff_pick_visible = pick_array_element(non_zero_diff, visible_dim_ids);
+ auto non_zero_diff_pick_visible = pick_container_element(non_zero_diff, visible_dim_ids);
static_for<0, ndim_visible, 1>{}([&non_zero_diff_pick_visible, &idx_diff_visible](auto i) {
non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0);
......@@ -487,8 +487,8 @@ make_dynamic_tensor_coordinate_step_v2(const TensorDesc&, const VisibleIndex& id
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
- const auto non_zero_diff_pick_up = pick_array_element(non_zero_diff, dims_up);
- auto non_zero_diff_pick_low = pick_array_element(non_zero_diff, dims_low);
+ const auto non_zero_diff_pick_up = pick_container_element(non_zero_diff, dims_up);
+ auto non_zero_diff_pick_low = pick_container_element(non_zero_diff, dims_low);
// if any of upper index diff components is non-zero, then
// 1) Need to do this transform
......@@ -526,7 +526,7 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
// initialize visible index diff
// idx_diff_hidden_pick_visible contains reference to idx_diff_hidden
auto idx_diff_hidden_pick_visible =
- pick_array_element(idx_diff_hidden, TensorDesc::GetVisibleDimensionIds());
+ pick_container_element(idx_diff_hidden, TensorDesc::GetVisibleDimensionIds());
idx_diff_hidden_pick_visible = coord_step.GetVisibleIndexDiff();
......@@ -535,7 +535,7 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
// update visible index
auto idx_hidden_pick_visible =
- pick_array_element(idx_hidden, TensorDesc::GetVisibleDimensionIds());
+ pick_container_element(idx_hidden, TensorDesc::GetVisibleDimensionIds());
idx_hidden_pick_visible += coord_step.GetIndexDiff();
// update rest of hidden index
......@@ -546,12 +546,12 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
- // this const is for ArrayElementPicker, Array itself may not be const
- const auto idx_up = pick_array_element(idx_hidden, dims_up);
- auto idx_low = pick_array_element(idx_hidden, dims_low);
+ // this const is for ContainerElementPicker, Array itself may not be const
+ const auto idx_up = pick_container_element(idx_hidden, dims_up);
+ auto idx_low = pick_container_element(idx_hidden, dims_low);
- const auto idx_diff_up = pick_array_element(idx_diff_hidden, dims_up);
- auto idx_diff_low = pick_array_element(idx_diff_hidden, dims_low);
+ const auto idx_diff_up = pick_container_element(idx_diff_hidden, dims_up);
+ auto idx_diff_low = pick_container_element(idx_diff_hidden, dims_low);
tran.CalculateLowerIndexDiff(idx_diff_low, idx_diff_up, idx_low, idx_up);
......@@ -579,7 +579,7 @@ coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& te
if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex())
{
const auto idx_up =
- pick_array_element(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran));
+ pick_container_element(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran));
valid = valid && tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up);
}
......
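
The comment rewritten above, "this const is for ContainerElementPicker, Array itself may not be const", is worth a gloss: the picker stores a reference, so const on the picker object only forbids writes through that particular view; the underlying array stays mutable. A reduced model of that behavior (not the CK class):

#include <array>
#include <cstddef>

// Reduced model: a view over an array that gates writes on its own
// constness, independent of the constness of the viewed array.
template <typename Arr>
struct PickerModel
{
    Arr& ref;

    int read(std::size_t i) const { return ref[i]; }  // ok on a const view
    void write(std::size_t i, int v) { ref[i] = v; }  // needs a non-const view
};

int main()
{
    std::array<int, 4> idx_hidden{{0, 1, 2, 3}};

    const PickerModel<std::array<int, 4>> idx_up{idx_hidden};
    PickerModel<std::array<int, 4>> idx_low{idx_hidden};

    int k = idx_up.read(2); // reading through the const view: fine
    idx_low.write(3, 7);    // writing goes through the non-const view
    // idx_up.write(0, 9);  // error: write() is non-const
    (void)k;
}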
......@@ -311,8 +311,8 @@ struct TransformedTensorDescriptor
static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms{}.At(itran);
- const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));
- auto idx_low_part = pick_array_element(idx_low, LowDimensionIds{}.At(itran));
+ const auto idx_up_part = pick_container_element(idx_up, UpDimensionIds{}.At(itran));
+ auto idx_low_part = pick_container_element(idx_low, LowDimensionIds{}.At(itran));
// this assumes each lower (single) index is only associated with one transformation,
// which is required for index transformation, and has been checked during constructor
......@@ -333,14 +333,16 @@ struct TransformedTensorDescriptor
constexpr auto tran = Transforms{}.At(itran);
const auto idx_up_diff_part =
- pick_array_element(idx_up_diff, UpDimensionIds{}.At(itran));
+ pick_container_element(idx_up_diff, UpDimensionIds{}.At(itran));
- const auto idx_up_old_part = pick_array_element(idx_up_old, UpDimensionIds{}.At(itran));
+ const auto idx_up_old_part =
+ pick_container_element(idx_up_old, UpDimensionIds{}.At(itran));
const auto idx_low_old_part =
- pick_array_element(idx_low_old, LowDimensionIds{}.At(itran));
+ pick_container_element(idx_low_old, LowDimensionIds{}.At(itran));
- auto idx_low_diff_part = pick_array_element(idx_low_diff, LowDimensionIds{}.At(itran));
+ auto idx_low_diff_part =
+ pick_container_element(idx_low_diff, LowDimensionIds{}.At(itran));
// this assumes each lower (single) index is associated with only one transformation,
// which is required for index transformation, and has been checked during constructor
......@@ -508,7 +510,7 @@ struct TransformedTensorDescriptor
constexpr auto low_lengths_part =
GetLowerTensorDescriptor().GetLengths(low_dims_part);
const auto idx_low_part =
- to_multi_index(pick_array_element(idx_low, low_dims_part));
+ to_multi_index(pick_container_element(idx_low, low_dims_part));
static_for<0, decltype(low_dims_part)::Size(), 1>{}([&](auto i) {
flag = flag && idx_low_part[i] >= 0 && idx_low_part[i] < low_lengths_part[i];
......
......@@ -116,8 +116,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
- const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
- const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
+ const index_t m_block_data_on_global = block_work_id[Number<0>{}] * MPerBlock;
+ const index_t n_block_data_on_global = block_work_id[Number<1>{}] * NPerBlock;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
......@@ -143,7 +143,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
- {0, m_block_data_on_global}, {0, 0});
+ make_multi_index(0, m_block_data_on_global), make_multi_index(0, 0));
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
......@@ -169,7 +169,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
- {0, n_block_data_on_global}, {0, 0});
+ make_multi_index(0, n_block_data_on_global), make_multi_index(0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -355,11 +355,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation>(
- {0, 0, 0, 0},
- {m_thread_data_on_global / M1,
- m_thread_data_on_global % M1,
- n_thread_data_on_global / N1,
- n_thread_data_on_global % N1})
+ make_multi_index(0, 0, 0, 0),
+ make_multi_index(m_thread_data_on_global / M1,
+ m_thread_data_on_global % M1,
+ n_thread_data_on_global / N1,
+ n_thread_data_on_global % N1))
.Run(p_c_thread, p_c_global);
}
}
......@@ -447,21 +447,23 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block) const
{
- constexpr auto True = integral_constant<bool, true>{};
- constexpr auto False = integral_constant<bool, false>{};
+ constexpr auto I0 = Number<0>{};
+ constexpr auto I1 = Number<1>{};
+ constexpr auto I2 = Number<2>{};
+ constexpr auto I3 = Number<3>{};
+ constexpr auto True = integral_constant<bool, true>{};
+ constexpr auto False = integral_constant<bool, false>{};
constexpr auto a_k0_k1_k2_m_global_desc = AGlobalDesc{};
constexpr auto b_k0_k1_k2_n_global_desc = BGlobalDesc{};
constexpr auto c_m_n_global_desc = CGlobalDesc{};
- constexpr auto K0 = a_k0_k1_k2_m_global_desc.GetLengths()[0];
- constexpr auto K1 = a_k0_k1_k2_m_global_desc.GetLengths()[1];
- constexpr auto K = a_k0_k1_k2_m_global_desc.GetLengths()[2];
- constexpr auto M = c_m_n_global_desc.GetLengths()[0];
- constexpr auto N = c_m_n_global_desc.GetLengths()[1];
+ constexpr auto K0 = a_k0_k1_k2_m_global_desc.GetLengths()[I0];
+ constexpr auto K1 = a_k0_k1_k2_m_global_desc.GetLengths()[I1];
+ constexpr auto K = a_k0_k1_k2_m_global_desc.GetLengths()[I2];
+ constexpr auto M = c_m_n_global_desc.GetLengths()[I0];
+ constexpr auto N = c_m_n_global_desc.GetLengths()[I1];
// don't do anything if K == 0
if(K == 0)
......@@ -487,8 +489,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
- const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
- const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
+ const index_t m_block_data_on_global = block_work_id[I0] * MPerBlock;
+ const index_t n_block_data_on_global = block_work_id[I1] * NPerBlock;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
......@@ -514,7 +516,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
- {0, 0, 0, m_block_data_on_global}, {0, 0, 0, 0});
+ make_multi_index(0, 0, 0, m_block_data_on_global), make_multi_index(0, 0, 0, 0));
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
......@@ -540,7 +542,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
- {0, 0, 0, n_block_data_on_global}, {0, 0, 0, 0});
+ make_multi_index(0, 0, 0, n_block_data_on_global), make_multi_index(0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -750,11 +752,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation>(
- {0, 0, 0, 0},
- {m_thread_data_on_global / M1,
- m_thread_data_on_global % M1,
- n_thread_data_on_global / N1,
- n_thread_data_on_global % N1})
+ make_multi_index(0, 0, 0, 0),
+ make_multi_index(m_thread_data_on_global / M1,
+ m_thread_data_on_global % M1,
+ n_thread_data_on_global / N1,
+ n_thread_data_on_global % N1))
.Run(p_c_thread, p_c_global);
}
}
......
......@@ -2,9 +2,9 @@
#define CK_COMMON_HEADER_HPP
#include "array.hpp"
#include "array_helper.hpp"
#include "container_helper.hpp"
#include "statically_indexed_array.hpp"
#include "array_element_picker.hpp"
#include "container_element_picker.hpp"
#include "config.hpp"
#include "float_type.hpp"
#include "functional.hpp"
......
- #ifndef CK_ARRAY_ELEMENT_PICKER_HPP
- #define CK_ARRAY_ELEMENT_PICKER_HPP
+ #ifndef CK_CONTAINER_ELEMENT_PICKER_HPP
+ #define CK_CONTAINER_ELEMENT_PICKER_HPP
#include "functional2.hpp"
#include "sequence.hpp"
......@@ -9,16 +9,16 @@ namespace ck {
// Arr: Array or StaticallyIndexedArray
// Picks: Sequence<...>
template <typename Arr, typename Picks>
- struct ArrayElementPicker
+ struct ContainerElementPicker
{
- using type = ArrayElementPicker;
+ using type = ContainerElementPicker;
#if 0
using data_type = typename Arr::data_type;
#endif
- __host__ __device__ constexpr ArrayElementPicker() = delete;
+ __host__ __device__ constexpr ContainerElementPicker() = delete;
- __host__ __device__ explicit constexpr ArrayElementPicker(Arr& array) : mArray{array}
+ __host__ __device__ explicit constexpr ContainerElementPicker(Arr& array) : mArray{array}
{
constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
......@@ -72,9 +72,9 @@ struct ArrayElementPicker
};
template <typename Arr, typename Picks, typename X>
- __host__ __device__ constexpr auto operator+=(ArrayElementPicker<Arr, Picks>& y, const X& x)
+ __host__ __device__ constexpr auto operator+=(ContainerElementPicker<Arr, Picks>& y, const X& x)
{
- using Y = ArrayElementPicker<Arr, Picks>;
+ using Y = ContainerElementPicker<Arr, Picks>;
constexpr index_t nsize = Y::Size();
static_assert(nsize == X::Size(), "wrong! size not the same");
......@@ -85,9 +85,9 @@ __host__ __device__ constexpr auto operator+=(ArrayElementPicker<Arr, Picks>& y,
}
template <typename Arr, typename Picks, typename X>
- __host__ __device__ constexpr auto operator-=(ArrayElementPicker<Arr, Picks>& y, const X& x)
+ __host__ __device__ constexpr auto operator-=(ContainerElementPicker<Arr, Picks>& y, const X& x)
{
- using Y = ArrayElementPicker<Arr, Picks>;
+ using Y = ContainerElementPicker<Arr, Picks>;
constexpr index_t nsize = Y::Size();
static_assert(nsize == X::Size(), "wrong! size not the same");
......@@ -98,9 +98,9 @@ __host__ __device__ constexpr auto operator-=(ArrayElementPicker<Arr, Picks>& y,
}
template <typename Arr, typename Picks>
- __host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
+ __host__ __device__ constexpr auto pick_container_element(Arr& a, Picks)
{
- return ArrayElementPicker<Arr, Picks>(a);
+ return ContainerElementPicker<Arr, Picks>(a);
}
} // namespace ck
......
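
Reading the renamed header as a whole: pick_container_element(a, Picks{}) returns a proxy whose element i aliases a[Picks::At(i)], and the operator+= and operator-= overloads update only those picked slots. A self-contained analogue of the aliasing behavior, with a runtime pick list standing in for CK's Sequence:

#include <array>
#include <cstddef>

// Analogue of ContainerElementPicker: exposes the elements of arr named
// by picks, in pick order, as mutable references into arr.
template <typename Arr, std::size_t N>
struct ElementPickerModel
{
    Arr& arr;
    std::array<std::size_t, N> picks;

    int& operator()(std::size_t i) { return arr[picks[i]]; }
};

template <typename Arr, std::size_t N>
ElementPickerModel<Arr, N> pick(Arr& a, const std::array<std::size_t, N>& picks)
{
    return {a, picks};
}

// Mirrors the operator+= overload above: add x into the picked slots only.
template <typename Arr, std::size_t N>
ElementPickerModel<Arr, N>& operator+=(ElementPickerModel<Arr, N>& y,
                                       const std::array<int, N>& x)
{
    for(std::size_t i = 0; i < N; ++i)
        y(i) += x[i];
    return y;
}

int main()
{
    // The tensor-coordinate code above views "visible" dimensions of a
    // longer "hidden" index in exactly this aliasing fashion.
    std::array<int, 5> hidden{};

    auto visible = pick(hidden, std::array<std::size_t, 2>{{1, 3}});
    visible += std::array<int, 2>{{10, 20}};
    // hidden is now {0, 10, 0, 20, 0}: writes land in the viewed storage.
}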
- #ifndef CK_ARRAY_HELPER_HPP
- #define CK_ARRAY_HELPER_HPP
+ #ifndef CK_CONTAINER_HELPER_HPP
+ #define CK_CONTAINER_HELPER_HPP
#include "sequence.hpp"
#include "sequence_helper.hpp"
#include "array.hpp"
#include "array_helper.hpp"
#include "tuple.hpp"
#include "tuple_helper.hpp"
#include "statically_indexed_array.hpp"
#include "array_element_picker.hpp"
#include "container_element_picker.hpp"
namespace ck {
......@@ -24,6 +23,18 @@ __host__ __device__ constexpr auto container_push_back(const Array<TData, NSize>
return r;
}
+ template <typename... Ts, typename T>
+ __host__ __device__ constexpr auto container_push_back(const Tuple<Ts...>& a, const T& x)
+ {
+ Tuple<Ts..., T> r;
+ static_for<0, sizeof...(Ts), 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
+ r(Number<sizeof...(Ts)>{}) = x;
+ return r;
+ }
template <typename TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto
container_reorder_given_new2old(const Array<TData, NSize>& old_array, Sequence<IRs...> /*new2old*/)
......
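
The new Tuple overload of container_push_back builds a Tuple<Ts..., T> one element longer, copying the old elements at compile-time indices and assigning x at position sizeof...(Ts). A standalone analogue using std::tuple, where std::tuple_cat does the element copying that CK's static_for loop spells out:

#include <tuple>

// Standalone analogue of the Tuple container_push_back added above:
// append x to tuple a, producing a tuple one element longer.
template <typename... Ts, typename T>
constexpr auto container_push_back(const std::tuple<Ts...>& a, const T& x)
{
    return std::tuple_cat(a, std::tuple<T>{x});
}

int main()
{
    constexpr auto a = std::make_tuple(1, 2.5);
    constexpr auto b = container_push_back(a, 'c'); // tuple<int, double, char>
    static_assert(std::tuple_size<decltype(b)>::value == 3, "one element longer");
    static_assert(std::get<2>(b) == 'c', "new element at the back");
}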
#ifndef CK_FUNCTIONAL3_HPP
#define CK_FUNCTIONAL3_HPP
#include "array.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "sequence.hpp"
#include "multi_index.hpp"
namespace ck {
......@@ -133,7 +133,7 @@ struct ford
for(index_t i = 0; i < ordered_lengths.Front(); ++i)
{
detail::ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f,
- Array<index_t, 1>{i});
+ make_multi_index(i));
}
}
};
......
......@@ -3,7 +3,7 @@
#include "array.hpp"
#include "statically_indexed_array.hpp"
#include "array_helper.hpp"
#include "container_helper.hpp"
#include "sequence.hpp"
namespace ck {
......
......@@ -19,9 +19,11 @@ struct TupleElement
{
__host__ __device__ explicit constexpr TupleElement() : mData() {}
+ #if 0
__host__ __device__ explicit constexpr TupleElement(const TupleElement&) = default;
__host__ __device__ explicit constexpr TupleElement(TupleElement&&) = default;
+ #endif
template <typename UData>
__host__ __device__ explicit constexpr TupleElement(const TupleElement<Key, UData>& te)
......@@ -73,9 +75,11 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
static_assert(sizeof...(Is) == sizeof...(Xs), "wrong! inconsistent size");
}
+ #if 0
__host__ __device__ explicit constexpr TupleImpl(const TupleImpl&) = default;
__host__ __device__ explicit constexpr TupleImpl(TupleImpl&&) = default;
+ #endif
template <index_t... Js, typename... Ys>
__host__ __device__ explicit constexpr TupleImpl(const TupleImpl<Sequence<Js...>, Ys...>& y)
......@@ -124,9 +128,11 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__ __device__ explicit constexpr Tuple() : base() {}
+ #if 0
__host__ __device__ constexpr Tuple(const Tuple&) = default;
__host__ __device__ constexpr Tuple(Tuple&&) = default;
+ #endif
template <typename... Ys,
typename std::enable_if<sizeof...(Ys) == sizeof...(Xs), bool>::type = false>
......
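
The tuple.hpp hunks fence the explicit defaulted copy and move constructors behind #if 0 rather than deleting them. The commit does not say why, but a plausible reading (an inference, not stated in the source) is that explicit on a copy constructor disables copy-initialization, which bites as soon as a Tuple is passed or returned by value:

// Sketch of the hazard (assumed motivation, not stated in the commit):
// an explicit copy constructor is invisible to copy-initialization.
struct Explicit
{
    Explicit() = default;
    explicit Explicit(const Explicit&) = default;
};

int main()
{
    Explicit a;
    Explicit b(a);     // ok: direct-initialization may call an explicit ctor
    // Explicit c = a; // error: copy-initialization may not
    // Returning or passing an Explicit by value fails the same way.
    (void)b;
}

With the declarations compiled out, the implicitly declared, non-explicit copy and move constructors take over, restoring normal value semantics.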
#ifndef CK_TUPLE_HELPER_HPP
#define CK_TUPLE_HELPER_HPP
#include "tuple_helper.hpp"
#include "functional4.hpp"
#include "tuple.hpp"
namespace ck {
......
......@@ -222,7 +222,7 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk(InDesc i
static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) {
constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);
- constexpr index_t gemm_k2 = gemm_sizes.At(4);
+ constexpr index_t gemm_k2 = gemm_sizes[Number<4>{}];
constexpr bool is_gemm_not_empty = gemm_k2 > 0;
// only compile and run if GEMM is not empty
......
......@@ -245,7 +245,7 @@ int main(int argc, char* argv[])
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 0
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
- #elif 1
+ #elif 0
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#elif 1
device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk
......
......@@ -561,7 +561,7 @@ int main(int argc, char* argv[])
LeftPads{},
RightPads{},
nrepeat);
- #elif 0
+ #elif 1
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
......