Commit 0d475c27 authored by Chao Liu
Browse files

refactor

parent 1c704471
...@@ -70,6 +70,14 @@ struct DynamicPassThrough ...@@ -70,6 +70,14 @@ struct DynamicPassThrough
{ {
return true; return true;
} }
// Debug helper, callable from host and device code: writes this transform to
// the printf stream as "{DynamicPassThrough, up_lengths_ <multi-index>}".
__host__ __device__ void Print() const
{
    printf("{");
    printf("DynamicPassThrough, ");
    // Label the field so the output reads like DynamicEmbed/DynamicMerge, which
    // name every member they print.
    printf("up_lengths_ ");
    print_multi_index(up_lengths_);
    printf("}");
}
}; };
template <bool SkipIsValidCheck = false> template <bool SkipIsValidCheck = false>
...@@ -145,6 +153,17 @@ struct DynamicPad ...@@ -145,6 +153,17 @@ struct DynamicPad
return SkipIsValidCheck || ((idx_up[Number<0>{}] >= left_pad_) && return SkipIsValidCheck || ((idx_up[Number<0>{}] >= left_pad_) &&
(idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_)); (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_));
} }
// Debug helper, callable from host and device code: writes this transform to
// the printf stream as
// "{DynamicPad, up_lengths_ <multi-index>, left_pad_ <n>, right_pad_ <n>}".
__host__ __device__ void Print() const
{
    printf("{");
    printf("DynamicPad, ");
    // Name the multi-index field; previously it was printed unlabeled.
    printf("up_lengths_ ");
    print_multi_index(up_lengths_);
    // Separator so the closing brace of the multi-index does not run straight
    // into the "left_pad_" label.
    printf(", ");
    printf("left_pad_ %d", left_pad_);
    printf(", ");
    printf("right_pad_ %d", right_pad_);
    printf("}");
}
}; };
template <bool SkipIsValidCheck = false> template <bool SkipIsValidCheck = false>
...@@ -214,6 +233,15 @@ struct DynamicLeftPad ...@@ -214,6 +233,15 @@ struct DynamicLeftPad
{ {
return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_); return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_);
} }
// Debug helper, callable from host and device code: writes this transform to
// the printf stream as "{DynamicLeftPad, up_lengths_ <multi-index>, left_pad_ <n>}".
__host__ __device__ void Print() const
{
    printf("{");
    printf("DynamicLeftPad, ");
    // Name the multi-index field; previously it was printed unlabeled.
    printf("up_lengths_ ");
    print_multi_index(up_lengths_);
    // Separator between the multi-index and the scalar that follows it.
    printf(", ");
    printf("left_pad_ %d", left_pad_);
    printf("}");
}
}; };
template <bool SkipIsValidCheck = false> template <bool SkipIsValidCheck = false>
...@@ -287,6 +315,15 @@ struct DynamicRightPad ...@@ -287,6 +315,15 @@ struct DynamicRightPad
{ {
return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_); return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_);
} }
// Debug helper, callable from host and device code: writes this transform to
// the printf stream as "{DynamicRightPad, up_lengths_ <multi-index>, right_pad_ <n>}".
__host__ __device__ void Print() const
{
    printf("{");
    printf("DynamicRightPad, ");
    // Name the multi-index field; previously it was printed unlabeled.
    printf("up_lengths_ ");
    print_multi_index(up_lengths_);
    // Separator between the multi-index and the scalar that follows it.
    printf(", ");
    // The value printed is right_pad_, so the label must say so (it used to
    // read "left_pad_", which mislabeled the output).
    printf("right_pad_ %d", right_pad_);
    printf("}");
}
}; };
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] // idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
...@@ -364,6 +401,17 @@ struct DynamicEmbed ...@@ -364,6 +401,17 @@ struct DynamicEmbed
{ {
return true; return true;
} }
// Debug helper, callable from host and device code: writes this transform to
// the printf stream as
// "{DynamicEmbed, up_lengths_ <multi-index>coefficients_ <multi-index>}".
// Each piece is emitted by its own printf call, in the order shown.
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicEmbed, ");
// Every field is preceded by its member name.
printf("up_lengths_ ");
print_multi_index(up_lengths_);
printf("coefficients_ ");
print_multi_index(coefficients_);
printf("}");
}
}; };
template <index_t NDimLow> template <index_t NDimLow>
...@@ -859,7 +907,20 @@ struct DynamicMerge ...@@ -859,7 +907,20 @@ struct DynamicMerge
{ {
return true; return true;
} }
}; // namespace ck
// Debug helper, callable from host and device code: writes this transform to
// the printf stream as "{DynamicMerge, low_lengths_ ...low_lengths_scan_ ...up_lengths_ ...}".
// Each piece is emitted by its own printf call, in the order shown.
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicMerge, ");
// Every field is preceded by its member name.
printf("low_lengths_ ");
print_multi_index(low_lengths_);
printf("low_lengths_scan_ ");
print_multi_index(low_lengths_scan_);
printf("up_lengths_ ");
print_multi_index(up_lengths_);
printf("}");
}
};
template <index_t NDimUp, bool Use24BitIntegerCalculation = false> template <index_t NDimUp, bool Use24BitIntegerCalculation = false>
struct DynamicUnMerge struct DynamicUnMerge
...@@ -938,6 +999,15 @@ struct DynamicUnMerge ...@@ -938,6 +999,15 @@ struct DynamicUnMerge
{ {
return true; return true;
} }
// Debug helper, callable from host and device code: writes this transform to
// the printf stream as
// "{DynamicUnMerge, up_lengths_ <multi-index>up_lengths_scan_ <multi-index>}".
__host__ __device__ void Print() const
{
    printf("{");
    printf("DynamicUnMerge, ");
    // Two multi-indices are printed back to back; label each one (as
    // DynamicMerge does) so the reader can tell which is which.
    printf("up_lengths_ ");
    print_multi_index(up_lengths_);
    printf("up_lengths_scan_ ");
    print_multi_index(up_lengths_scan_);
    printf("}");
}
}; };
struct DynamicFreeze struct DynamicFreeze
...@@ -997,6 +1067,8 @@ struct DynamicFreeze ...@@ -997,6 +1067,8 @@ struct DynamicFreeze
{ {
return true; return true;
} }
// Debug helper, callable from host and device code: writes "{DynamicFreeze}"
// to the printf stream. The braces match the other transforms' Print() output
// so the name stays delimited when text is printed right after it.
__host__ __device__ void Print() const { printf("{DynamicFreeze}"); }
}; };
} // namespace ck } // namespace ck
......
...@@ -146,6 +146,23 @@ struct DynamicTensorDescriptor ...@@ -146,6 +146,23 @@ struct DynamicTensorDescriptor
return hidden_lengths; return hidden_lengths;
} }
// Debug helper, callable from host and device code: writes the descriptor to
// the printf stream. For each transform index in [0, ntransform_) it prints the
// transform (via its own Print()) followed by the matching lower and upper
// dimension-id lists, then the visible dimension ids, all inside one pair of
// braces.
__host__ __device__ void Print() const
{
    printf("{");
    printf("DynamicTensorDescriptor, ");
    static_for<0, ntransform_, 1>{}([&](auto i) {
        printf("transforms: ");
        transforms_[i].Print();
        printf("LowerDimensionIds:");
        LowerDimensionIdss{}.At(i).Print();
        printf("UpperDimensionIds:");
        UpperDimensionIdss{}.At(i).Print();
    });
    // Print the visible ids with a label and before the closing brace; they
    // used to be emitted unlabeled after the "}" and so looked detached from
    // this descriptor.
    printf("VisibleDimensionIds:");
    VisibleDimensionIds::Print();
    printf("}");
}
// TODO make these private // TODO make these private
Transforms transforms_; Transforms transforms_;
// TODO maybe hidden_lengths_ should use reference_wrapper (reference to transforms_'s member // TODO maybe hidden_lengths_ should use reference_wrapper (reference to transforms_'s member
......
...@@ -163,6 +163,16 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x) ...@@ -163,6 +163,16 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
return r; return r;
} }
// Debug helper, callable from host and device code: writes a tuple of indices
// to the printf stream as "{MultiIndex, size <N>, <e0> <e1> ... }".
//
// x: tuple whose elements are printed in order with "%d".
// NOTE(review): "%d" assumes index_t is int-sized — confirm against the
// index_t typedef.
template <typename... Xs>
__host__ __device__ void print_multi_index(const Tuple<Xs...>& x)
{
    printf("{");
    printf("MultiIndex, ");
    // ", " after the size matches the spacing Sequence::Print uses.
    printf("size %d, ", index_t{sizeof...(Xs)});
    // printf is variadic, so nothing converts the argument to what "%d"
    // expects; cast each element explicitly instead of relying on every tuple
    // element type being passed as an int.
    static_for<0, sizeof...(Xs), 1>{}(
        [&](auto i) { printf("%d ", static_cast<index_t>(x.At(i))); });
    printf("}");
}
#endif #endif
} // namespace ck } // namespace ck
#endif #endif
...@@ -278,7 +278,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -278,7 +278,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto b_k_n_global_move_slice_window_iterator_hack = constexpr auto b_k_n_global_move_slice_window_iterator_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
#elif 1 #elif 0
// for non-padded input // for non-padded input
constexpr auto b_k_n_global_iterator_hacks = make_tuple( constexpr auto b_k_n_global_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 1, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 1>{}), make_tuple(Sequence<0, 0, 0, 0, 0, 1, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 1>{}),
......
...@@ -168,6 +168,14 @@ struct Sequence ...@@ -168,6 +168,14 @@ struct Sequence
{ {
return Sequence<f(Is)...>{}; return Sequence<f(Is)...>{};
} }
// Debug helper, callable from host and device code: writes this sequence to
// the printf stream as "{size <N>, <e0> <e1> ... }".
__host__ __device__ static void Print()
{
    printf("{");
    printf("size %d, ", index_t{Size()});

    // Called by static_for once per index in [0, Size()); prints that element
    // followed by a space.
    const auto print_element = [&](auto idx) { printf("%d ", At(idx).value); };
    static_for<0, Size(), 1>{}(print_element);

    printf("}");
}
}; };
// merge sequence // merge sequence
......
...@@ -235,7 +235,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc ...@@ -235,7 +235,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
constexpr auto conv_driver = constexpr auto conv_driver =
#if 1 #if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
#elif 1 #elif 0
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
#elif 1 #elif 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
......
...@@ -67,7 +67,7 @@ int main(int argc, char* argv[]) ...@@ -67,7 +67,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 1
// 1x1, 8x8 // 1x1, 8x8
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1536; constexpr index_t C = 1536;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment