Commit c15a3c09 authored by Chao Liu
Browse files

refactor most loops to static_for

parent 9ae20fc2
...@@ -439,12 +439,12 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -439,12 +439,12 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
1, 1,
AddressSpace::Vgpr, AddressSpace::Vgpr,
AddressSpace::Global, AddressSpace::Global,
InMemoryDataOperation::Set>({0, 0, 0, 0, 0}, InMemoryDataOperation::Set>(make_multi_index(0, 0, 0, 0, 0),
{k_thread_data_on_global / K1, make_multi_index(k_thread_data_on_global / K1,
k_thread_data_on_global % K1, k_thread_data_on_global % K1,
0, 0,
b_thread_data_on_global, b_thread_data_on_global,
0}) 0))
.Run(p_out_thread, p_out_global); .Run(p_out_thread, p_out_global);
} }
} }
......
...@@ -350,7 +350,7 @@ struct UnMerge ...@@ -350,7 +350,7 @@ struct UnMerge
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up) __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{ {
LowerIndex idx_low{{0}}; LowerIndex idx_low = make_multi_index(0);
constexpr auto pseudo_up_strides = constexpr auto pseudo_up_strides =
reverse_inclusive_scan_sequence( reverse_inclusive_scan_sequence(
...@@ -358,7 +358,7 @@ struct UnMerge ...@@ -358,7 +358,7 @@ struct UnMerge
.PushBack(Number<1>{}); .PushBack(Number<1>{});
static_for<0, nDimUp, 1>{}( static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low(0) += idx_up[idim] * pseudo_up_strides[idim]; }); [&](auto idim) { idx_low(Number<0>{}) += idx_up[idim] * pseudo_up_strides[idim]; });
return idx_low; return idx_low;
} }
...@@ -459,25 +459,17 @@ struct Embed ...@@ -459,25 +459,17 @@ struct Embed
index_t itmp = icorner; index_t itmp = icorner;
#if 0
for(index_t idim = nDimUp - 1; idim >= 0; --idim)
{
idx_up(idim) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim) - 1;
itmp /= 2;
}
#else
static_for<nDimUp, 0, -1>{}([&](auto idim) { static_for<nDimUp, 0, -1>{}([&](auto idim) {
auto idim_m1 = idim - Number<1>{}; auto idim_m1 = idim - Number<1>{};
idx_up(idim_m1) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim_m1) - 1; idx_up(idim_m1) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim_m1) - 1;
itmp /= 2; itmp /= 2;
}); });
#endif
// calculate lower index // calculate lower index
auto idx_low = CalculateLowerIndex(idx_up); auto idx_low = CalculateLowerIndex(idx_up);
// judge if lower index is valid // judge if lower index is valid
flag = flag && idx_low[0] >= 0 && idx_low[0] < LowerLength; flag = flag && idx_low[Number<0>{}] >= 0 && idx_low[Number<0>{}] < LowerLength;
} }
return flag; return flag;
......
...@@ -95,15 +95,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -95,15 +95,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
#if 1 #if 1
// zero out buffer // zero out buffer
for(index_t i = 0; i < long_vector_size; ++i) static_for<0, long_vector_size, 1>{}([&](auto i) { p_src_long_vector[i] = 0; });
{
p_src_long_vector[i] = 0;
}
#endif #endif
// load data from src to the long-vector buffer // load data from src to the long-vector buffer
for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i) static_for<0, long_vector_size / src_data_per_access, 1>{}([&](auto i) {
{
auto scalar_id = make_zero_multi_index<nDim>(); auto scalar_id = make_zero_multi_index<nDim>();
scalar_id(vector_access_dim) = i * src_data_per_access; scalar_id(vector_access_dim) = i * src_data_per_access;
...@@ -130,19 +126,17 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -130,19 +126,17 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
buffer_offset, buffer_offset,
true, true,
long_vector_size); long_vector_size);
} });
// SrcData to DstData conversion // SrcData to DstData conversion
DstData p_dst_long_vector[long_vector_size]; DstData p_dst_long_vector[long_vector_size];
for(index_t i = 0; i < long_vector_size; ++i) static_for<0, long_vector_size, 1>{}([&](auto i) {
{
p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]); p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
} });
// store data from the long-vector buffer to dst // store data from the long-vector buffer to dst
for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i) static_for<0, long_vector_size / dst_data_per_access, 1>{}([&](auto i) {
{
auto scalar_id = make_zero_multi_index<nDim>(); auto scalar_id = make_zero_multi_index<nDim>();
scalar_id(vector_access_dim) = i * dst_data_per_access; scalar_id(vector_access_dim) = i * dst_data_per_access;
...@@ -169,7 +163,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -169,7 +163,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
dst_coord.GetOffset(), dst_coord.GetOffset(),
dst_coord.IsOffsetValidAssumingUpperIndexIsValid(), dst_coord.IsOffsetValidAssumingUpperIndexIsValid(),
DstDesc::GetElementSpace()); DstDesc::GetElementSpace());
} });
}); });
} }
......
...@@ -496,10 +496,10 @@ int main(int argc, char* argv[]) ...@@ -496,10 +496,10 @@ int main(int argc, char* argv[])
ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
print_array("LeftPads", LeftPads{}); print_array("LeftPads", to_multi_index(LeftPads{}));
print_array("RightPads", RightPads{}); print_array("RightPads", to_multi_index(RightPads{}));
print_array("ConvStrides", ConvStrides{}); print_array("ConvStrides", to_multi_index(ConvStrides{}));
print_array("ConvDilations", ConvDilations{}); print_array("ConvDilations", to_multi_index(ConvDilations{}));
#if 1 #if 1
using in_data_t = float; using in_data_t = float;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment