Commit f16356d4 authored by Chao Liu

adding n-D RunWrite for threadwise copy v3

parent 1e347018
@@ -269,7 +269,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));

// hack to control index calculation when iterating over b_k_n_global tensor
-#if 0
+#if 1
// for padded input
constexpr auto b_k_n_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
...
@@ -686,6 +686,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
}
#endif

+#if 0
__device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
{
static_assert(remove_reference_t<DstDesc>::GetNumOfDimension() == 2,
@@ -762,6 +763,232 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, dst_reset_iterator);
}
}
#else
template <typename DstIteratorHacks>
__device__ void
RunWrite(const DstDesc& dst_desc, DstData* p_dst, const DstIteratorHacks& dst_iterator_hacks)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// TODO: don't use this
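// dst_scalar_per_access: DstScalarPerVector along DstVectorDim and 1 on every other
// dimension, i.e. how many scalars a single access covers in each dimension.
// dst_scalar_step_in_vector: unit step along DstVectorDim, used below to address the
// individual scalars inside one vector access.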
constexpr auto dst_scalar_per_access = generate_sequence(
lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector =
generate_sequence(lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dst_dim_access_order);
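// access lengths reordered from logical dimension order into DstDimAccessOrder;
// the static_ford loop below iterates them with the right-most entry varying fastest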
#if 0
// make forward iterators
const auto dst_forward_iterators = generate_tuple(
[&](auto i) {
Index forward_step;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
const auto forward_iterator = make_dynamic_tensor_coordinate_iterator(
dst_desc, forward_step, dst_iterator_hacks[I0][i]);
return forward_iterator;
},
Number<nDim>{});
// make backward iterators
const auto dst_backward_iterators = generate_tuple(
[&](auto i) {
Index backward_step;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
const auto backward_iterator = make_dynamic_tensor_coordinate_iterator(
dst_desc, backward_step, dst_iterator_hacks[I1][i]);
return backward_iterator;
},
Number<nDim>{});
#elif 0
const auto dst_forward_iterators = make_tuple(
make_dynamic_tensor_coordinate_iterator(dst_desc,
make_multi_index(1, 0) * dst_scalar_per_access,
dst_iterator_hacks[I0][I0]),
make_dynamic_tensor_coordinate_iterator(dst_desc,
make_multi_index(0, 1) * dst_scalar_per_access,
dst_iterator_hacks[I0][I1]));
const auto dst_backward_iterators = make_tuple(
make_dynamic_tensor_coordinate_iterator(dst_desc,
make_multi_index(-1, 0) * dst_scalar_per_access,
dst_iterator_hacks[I1][I0]),
make_dynamic_tensor_coordinate_iterator(dst_desc,
make_multi_index(0, -1) * dst_scalar_per_access,
dst_iterator_hacks[I1][I1]));
#else
const auto tmp0 = make_dynamic_tensor_coordinate_iterator(
dst_desc, make_multi_index(1, 0) * dst_scalar_per_access, dst_iterator_hacks[I0][I0]);
const auto tmp1 = make_dynamic_tensor_coordinate_iterator(
dst_desc, make_multi_index(0, 1) * dst_scalar_per_access, dst_iterator_hacks[I0][I1]);
const auto dst_forward_iterators = make_tuple(tmp0, tmp1);
const auto tmp2 = make_dynamic_tensor_coordinate_iterator(
dst_desc, make_multi_index(-1, 0) * dst_scalar_per_access, dst_iterator_hacks[I1][I0]);
const auto tmp3 = make_dynamic_tensor_coordinate_iterator(
dst_desc, make_multi_index(0, -1) * dst_scalar_per_access, dst_iterator_hacks[I1][I1]);
const auto dst_backward_iterators = make_tuple(tmp2, tmp3);
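// the forward/backward iterators step the dst coordinate by +/- one vector access
// along each dimension (hardcoded here for the 2-D case); the hack sequences steer
// the index calculation done inside move_dynamic_tensor_coordinate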
#endif
// loop over tensor and copy
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// judge move forward or move backward
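// the traversal is snake-like: the sweep direction of dimension i flips with the
// parity of the combined index of the more-significant dimensions in the access
// order, so consecutive accesses stay adjacent and only one dimension moves per step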
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
forward_sweep(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
});
forward_sweep(i) = tmp % 2 == 0;
});
return forward_sweep;
}();
// calculate dst data index
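// on a backward sweep the per-dimension index is mirrored (length - 1 - idx),
// then mapped back to the original dimension order and scaled by the access size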
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i]
? ordered_access_idx[i]
: ordered_access_lengths[i] - 1 - ordered_access_idx[i];
});
auto dst_data_idx =
container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
return dst_data_idx;
}();
// copy data
// hardcoding for ds_write
// TODO refactor transfer_data() to encapsulate this
static_assert(DstAddressSpace == AddressSpace::Lds &&
DstInMemOp == InMemoryDataOperation::Set,
"wrong! hardcoded for ds_write");
vector_type<DstData, DstScalarPerVector> dst_vector;
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t buffer_offset =
buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
dst_vector(i) = buffer_[Number<buffer_offset>{}];
});
using DstVectorType = typename vector_type<DstData, DstScalarPerVector>::MemoryType;
*reinterpret_cast<DstVectorType*>(p_dst + dst_slice_origin_coord_.GetOffset()) =
dst_vector.Vector();
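// note: the vectorized store assumes p_dst + offset is suitably aligned for DstVectorType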
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
});
});
return move_on_dim;
}
();
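// move_on_dim[i] is true only when dimension i still has accesses left and every
// faster-varying dimension after it has reached its last access, so exactly one
// dimension is advanced per iteration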
// move
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_dynamic_tensor_coordinate(
dst_desc,
dst_slice_origin_coord_,
dst_forward_iterators[dst_dim_access_order[i]]);
}
else
{
move_dynamic_tensor_coordinate(
dst_desc,
dst_slice_origin_coord_,
dst_backward_iterators[dst_dim_access_order[i]]);
}
}
});
});
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_iterator =
make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, dst_reset_iterator);
}
}
__device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
{
constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
constexpr auto seq = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
#if 1
constexpr auto dst_iterator_hacks = make_tuple(make_tuple(seq, seq), make_tuple(seq, seq));
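// one all-zero Sequence per transform of DstDesc, for both forward and backward
// iterators, i.e. no special index-calculation treatment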
#elif 0
constexpr auto dst_iterator_hacks = make_tuple(make_tuple(Sequence<0>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<0>{}));
#elif 1
constexpr auto dst_scalar_per_access = generate_sequence(
lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
const auto dst_forward_iterators =
make_tuple(make_dynamic_tensor_coordinate_iterator(
DstDesc{}, make_multi_index(1, 0) * dst_scalar_per_access),
make_dynamic_tensor_coordinate_iterator(
DstDesc{}, make_multi_index(0, 1) * dst_scalar_per_access));
const auto dst_backward_iterators =
make_tuple(make_dynamic_tensor_coordinate_iterator(
dst_desc, make_multi_index(-1, 0) * dst_scalar_per_access),
make_dynamic_tensor_coordinate_iterator(
dst_desc, make_multi_index(0, -1) * dst_scalar_per_access));
#endif
RunWrite(dst_desc, p_dst, dst_iterator_hacks);
}
#endif
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
...
@@ -233,7 +233,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

constexpr auto conv_driver =
-#if 0
+#if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
#elif 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
...
@@ -37,7 +37,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
-#elif 1
+#elif 0
// 3x3, 35x35, stride 2
constexpr index_t N = 128;
constexpr index_t C = 192;
...
@@ -67,7 +67,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
-#elif 1
+#elif 0
// 1x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 1536;
...
@@ -127,13 +127,13 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
-#elif 0
+#elif 1
// 7x1, 17x17
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
-constexpr index_t K = 384;
+constexpr index_t K = 128;
constexpr index_t Y = 7;
constexpr index_t X = 1;
...