Commit 9ae20fc2 authored by Chao Liu

refactor

parent 34cbbb48
@@ -10,9 +10,9 @@ template <index_t N>
 using MultiIndex = Array<index_t, N>;
 
 template <typename... Xs>
-__host__ __device__ constexpr auto make_multi_index(const Xs&... xs)
+__host__ __device__ constexpr auto make_multi_index(Xs&&... xs)
 {
-    return make_array<const index_t>(std::forward<const Xs>(xs)...);
+    return make_array<index_t>(index_t{xs}...);
 }
 
 template <index_t NSize>
@@ -79,9 +79,9 @@ template <index_t N>
 using MultiIndex = StaticallyIndexedArray<index_t, N>;
 
 template <typename... Xs>
-__host__ __device__ constexpr auto make_multi_index(const Xs&... xs)
+__host__ __device__ constexpr auto make_multi_index(Xs&&... xs)
 {
-    return make_statically_indexed_array<const index_t>(std::forward<const Xs>(xs)...);
+    return make_statically_indexed_array<index_t>(index_t{xs}...);
 }
 
 template <index_t NSize>
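Both hunks make the same change to the two MultiIndex implementations: make_multi_index now takes forwarding references and converts every argument explicitly with index_t{xs}..., instead of forwarding const references into make_array<const index_t>. A minimal sketch of the idea, using std::array and assumed names rather than the library's own containers:

#include <array>
#include <cstdint>

using index_t = int32_t;

// make_multi_index_sketch stands in for make_multi_index; std::array stands
// in for the library's Array/StaticallyIndexedArray.
template <typename... Xs>
constexpr auto make_multi_index_sketch(Xs&&... xs)
{
    // index_t{xs} is ill-formed if Xs does not convert to index_t without
    // narrowing, so a bad argument type is rejected at the call site.
    return std::array<index_t, sizeof...(Xs)>{index_t{xs}...};
}

static_assert(make_multi_index_sketch(1, 2, 3)[1] == 2, "element check");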
@@ -232,7 +232,7 @@ struct Merge
                       const UpperIndex& /* idx_up_old */,
                       const LowerIndex& idx_low_old)
     {
-        if(idx_up_diff[0] == 0)
+        if(idx_up_diff[Number<0>{}] == 0)
         {
             return make_zero_multi_index<nDimLow>();
         }
@@ -257,7 +257,7 @@ struct Merge
 
         LowerIndex idx_low_new = idx_low_old + idx_low_diff_tmp;
 
-        if(idx_up_diff[0] > 0)
+        if(idx_up_diff[Number<0>{}] > 0)
         {
             // do carry check on each low dimension in reversed order
             // starting from the first digit that changed
@@ -285,7 +285,7 @@ struct Merge
                 // highest dimension, no out-of-bound check
                 if(carry)
                 {
-                    ++idx_low_new(0);
+                    ++idx_low_new(Number<0>{});
                 }
             }
         }
         else
@@ -316,7 +316,7 @@ struct Merge
 
                 // highest dimension, no out-of-bound check
                 if(borrow)
                 {
-                    --idx_low_new(0);
+                    --idx_low_new(Number<0>{});
                 }
             }
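The Merge hunks replace runtime subscripts like idx_up_diff[0] and idx_low_new(0) with the compile-time form idx_up_diff[Number<0>{}]. A sketch of why a tuple-like container wants the index as a type rather than a value, built from standard-library stand-ins (Number here is just std::integral_constant; the real library types differ):

#include <cstdio>
#include <tuple>
#include <type_traits>

template <int N>
using Number = std::integral_constant<int, N>;

// A tuple-like container can hold a different type per slot, so the element
// can only be resolved when the index is known at compile time: operator[]
// and operator() receive the index as a type, not a value.
template <typename... Ts>
struct StaticArraySketch
{
    std::tuple<Ts...> data_;

    template <int I>
    constexpr auto operator[](Number<I>) const { return std::get<I>(data_); }

    template <int I>
    auto& operator()(Number<I>) { return std::get<I>(data_); } // mutable access
};

int main()
{
    StaticArraySketch<int, long> a{{1, 2L}};
    ++a(Number<0>{}); // same shape as ++idx_low_new(Number<0>{})
    std::printf("%d %ld\n", a[Number<0>{}], a[Number<1>{}]);
}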
@@ -411,12 +411,10 @@ struct Embed
 
     __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
     {
-        LowerIndex idx_low{{Coefficients{}[nDimUp]}};
-
-        for(index_t i = 0; i < nDimUp; ++i)
-        {
-            idx_low(0) += idx_up[i] * Coefficients{}[i];
-        }
+        LowerIndex idx_low = make_multi_index(Coefficients{}[Number<nDimUp>{}]);
+
+        static_for<0, nDimUp, 1>{}(
+            [&](auto i) { idx_low(Number<0>{}) += idx_up[i] * Coefficients{}[i]; });
 
         return idx_low;
     }
@@ -426,12 +424,10 @@ struct Embed
                             const UpperIndex& /* idx_up_old */,
                             const LowerIndex& /* idx_low_old */)
     {
-        LowerIndex idx_low_diff{0};
-
-        for(index_t i = 0; i < nDimUp; ++i)
-        {
-            idx_low_diff(0) += idx_up_diff[i] * Coefficients{}[i];
-        }
+        LowerIndex idx_low_diff = make_multi_index(0);
+
+        static_for<0, nDimUp, 1>{}(
+            [&](auto i) { idx_low_diff(Number<0>{}) += idx_up_diff[i] * Coefficients{}[i]; });
 
         return idx_low_diff;
     }
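Both Embed hunks convert a runtime for loop over nDimUp into static_for, which hands the body an integral_constant per iteration so the body can use compile-time indexing. A self-contained sketch of a static_for with the same begin/end/step shape, under assumed semantics rather than the library's exact implementation:

#include <cstdio>
#include <type_traits>

template <int N>
using Number = std::integral_constant<int, N>;

// Unrolls f(Number<Begin>{}), f(Number<Begin+Step>{}), ... until End is hit;
// each call sees the index as a distinct integral_constant type.
template <int Begin, int End, int Step>
struct static_for_sketch
{
    template <typename F>
    void operator()(F f) const
    {
        if constexpr(Begin != End)
        {
            f(Number<Begin>{});
            static_for_sketch<Begin + Step, End, Step>{}(f);
        }
    }
};

int main()
{
    int sum = 0;
    // forward form, as in the Embed hunks: i = 0, 1, 2, 3
    static_for_sketch<0, 4, 1>{}([&](auto i) { sum += i.value; });

    // reversed form, as in the corner loop below: i = 4, 3, 2, 1
    static_for_sketch<4, 0, -1>{}([&](auto i) { std::printf("%d ", i.value); });

    std::printf("sum=%d\n", sum);
}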
@@ -463,11 +459,19 @@ struct Embed
 
         index_t itmp = icorner;
 
+#if 0
         for(index_t idim = nDimUp - 1; idim >= 0; --idim)
         {
             idx_up(idim) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim) - 1;
             itmp /= 2;
         }
+#else
+        static_for<nDimUp, 0, -1>{}([&](auto idim) {
+            auto idim_m1 = idim - Number<1>{};
+            idx_up(idim_m1) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim_m1) - 1;
+            itmp /= 2;
+        });
+#endif
 
         // calculate lower index
         auto idx_low = CalculateLowerIndex(idx_up);
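For reference, the loop kept under #if 0 (and its static_for replacement) enumerates the corners of the upper-index box: bit idim of icorner selects either 0 or length-1 in dimension idim. A plain-C++ illustration with made-up lengths:

#include <cstdio>

int main()
{
    const int ndim       = 3;
    const int lengths[3] = {4, 5, 6}; // hypothetical UpperLengths

    // 2^ndim corners; each bit of icorner picks one face per dimension
    for(int icorner = 0; icorner < (1 << ndim); ++icorner)
    {
        int idx[3];
        int itmp = icorner;

        // same arithmetic as the hunk, consuming bits from the last dimension
        for(int idim = ndim - 1; idim >= 0; --idim)
        {
            idx[idim] = itmp % 2 == 0 ? 0 : lengths[idim] - 1;
            itmp /= 2;
        }

        std::printf("corner %d: (%d, %d, %d)\n", icorner, idx[0], idx[1], idx[2]);
    }
}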
@@ -504,7 +508,7 @@ struct Freeze
 
     __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& /*idx_up*/)
     {
-        return to_array(LowerFreezePoint{});
+        return to_multi_index(LowerFreezePoint{});
     }
 
     __host__ __device__ static constexpr auto
@@ -510,10 +510,9 @@ struct TransformedTensorDescriptor
 
                 const auto idx_low_part =
                     to_multi_index(pick_array_element(idx_low, low_dims_part));
 
-                for(index_t i = 0; i < low_dims_part.Size(); ++i)
-                {
-                    flag = flag && idx_low_part[i] >= 0 && idx_low_part[i] < low_lengths_part[i];
-                }
+                static_for<0, decltype(low_dims_part)::Size(), 1>{}([&](auto i) {
+                    flag = flag && idx_low_part[i] >= 0 && idx_low_part[i] < low_lengths_part[i];
+                });
             }
         });
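One detail worth noting in the TransformedTensorDescriptor hunk: the loop bound becomes decltype(low_dims_part)::Size() rather than low_dims_part.Size(). A template argument must be a constant expression, and to my reading low_dims_part is captured by reference here, so the object itself cannot be named in one, while its type can. A minimal sketch of that idiom with assumed names:

#include <cstdio>

template <typename T, int N>
struct ArraySketch
{
    T data_[N];
    static constexpr int Size() { return N; }
};

int main()
{
    ArraySketch<int, 3> arr{{1, 2, 3}};

    auto f = [&] {
        // "arr" is captured by reference, so naming it inside a constant
        // expression is ill-formed; naming its type via decltype is fine.
        constexpr int n = decltype(arr)::Size();
        std::printf("size=%d first=%d\n", n, arr.data_[0]);
    };
    f();
}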
@@ -17,13 +17,19 @@ __host__ __device__ void print_array(const char* s, T a)
     if constexpr(is_same<data_type, uint32_t>{})
     {
         printf("%s size %u, {", s, nsize);
-        static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", a[i]); });
+        static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", uint32_t{a[i]}); });
         printf("}\n");
     }
-    else if constexpr(true)
+    else if constexpr(is_same<data_type, int32_t>{})
     {
         printf("%s size %d, {", s, nsize);
-        static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", a[i]); });
+        static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
+        printf("}\n");
+    }
+    else if constexpr(is_same<data_type, bool>{})
+    {
+        printf("%s size %d, {", s, nsize);
+        static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", bool{a[i]}); });
         printf("}\n");
     }
 }
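The print_array hunk replaces the catch-all else if constexpr(true) with explicit branches for int32_t and bool, and brace-converts each element (uint32_t{a[i]}, int32_t{a[i]}) so the argument type matches the printf format specifier exactly. A sketch of the same dispatch pattern over a plain array (assumed signature, not the library's print.hpp):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <type_traits>

template <typename T, std::size_t N>
void print_array_sketch(const char* s, const T (&a)[N])
{
    std::printf("%s size %zu, {", s, N);

    for(std::size_t i = 0; i < N; ++i)
    {
        // each branch passes printf exactly the type its specifier expects;
        // brace-init would also flag narrowing if T changed underneath
        if constexpr(std::is_same<T, uint32_t>::value)
            std::printf("%u, ", uint32_t{a[i]});
        else if constexpr(std::is_same<T, int32_t>::value)
            std::printf("%d, ", int32_t{a[i]});
        else if constexpr(std::is_same<T, bool>::value)
            std::printf("%d, ", int{a[i]}); // bool promotes to int for %d
    }

    std::printf("}\n");
}

int main()
{
    const int32_t v[3] = {1, -2, 3};
    print_array_sketch("v", v);
}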
@@ -40,7 +46,7 @@ __host__ __device__ void print_array_v2(const char* s, T a)
         static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%u] %u, ", i.value, a[i]); });
         printf("}\n");
     }
-    else if constexpr(true)
+    else if constexpr(is_same<data_type, int32_t>{})
     {
         printf("%s size %d, {", s, nsize);
         static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
@@ -96,7 +96,7 @@ void device_dummy_dynamic_transform_v2(InDesc,
     for(index_t j = 0; j < nrepeat; ++j)
     {
 
-#if 1
+#if 0
         launch_kernel(run_gridwise_operation<DummyDynamicTransform_v2_1<BlockSize>,
                                              index_t* const,
                                              float* const,
@@ -4,10 +4,7 @@
 #include <cstdlib>
 #include <stdlib.h>
 #include "config.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
-#include "print_array.hpp"
-#include "print_sequence.hpp"
+#include "print.hpp"
 #include "device.hpp"
 #include "host_tensor_generator.hpp"
 #include "device_tensor.hpp"
@@ -211,11 +208,11 @@ int main(int argc, char* argv[])
     ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
     ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
     ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
 
-    print_sequence("LeftPads", LeftPads{});
-    print_sequence("LeftPads", LeftPads{});
-    print_sequence("RightPads", RightPads{});
-    print_sequence("ConvStrides", ConvStrides{});
-    print_sequence("ConvDilations", ConvDilations{});
+    print_array("LeftPads", LeftPads{});
+    print_array("LeftPads", LeftPads{});
+    print_array("RightPads", RightPads{});
+    print_array("ConvStrides", ConvStrides{});
+    print_array("ConvDilations", ConvDilations{});
 
     Tensor<float> in_nchw_device(make_HostTensorDescriptor(in_nchw_desc));
     Tensor<float> in_nchw_host(make_HostTensorDescriptor(in_nchw_desc));
@@ -248,7 +245,7 @@ int main(int argc, char* argv[])
     device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
 #elif 0
     device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
-#elif 0
+#elif 1
     device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
 #elif 1
     device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk
@@ -573,7 +573,7 @@ int main(int argc, char* argv[])
                                   LeftPads{},
                                   RightPads{},
                                   nrepeat);
-#elif 0
+#elif 1
     device_dummy_static_transform(in_nchw_desc,
                                   in_nchw,
                                   wei_kcyx_desc,