Commit 4d13badd authored by Chao Liu's avatar Chao Liu
Browse files

clean up

parent b23e7f8e
...@@ -156,7 +156,7 @@ struct DynamicTensorDescriptor_v2 ...@@ -156,7 +156,7 @@ struct DynamicTensorDescriptor_v2
hidden_lengths(0) = element_space_size; hidden_lengths(0) = element_space_size;
// lengths for all other hidden dimensions // lengths for all other hidden dimensions
static_for<0, ntransform_, 1>{}([&](auto itran) { static_for<0, ntransform_, 1>{}([&transforms, &hidden_lengths](auto itran) {
const auto& tran = transforms.At(itran); const auto& tran = transforms.At(itran);
constexpr auto up_dim_ids = UpperDimensionIdss{}.At(itran); constexpr auto up_dim_ids = UpperDimensionIdss{}.At(itran);
...@@ -408,17 +408,12 @@ make_dynamic_tensor_coordinate_v2(const TensorDesc& tensor_desc, const VisibleIn ...@@ -408,17 +408,12 @@ make_dynamic_tensor_coordinate_v2(const TensorDesc& tensor_desc, const VisibleIn
MultiIndex<ndim_hidden> idx_hidden; MultiIndex<ndim_hidden> idx_hidden;
// initialize visible index
auto idx_hidden_pick_visible = pick_array_element(idx_hidden, visible_dim_ids); auto idx_hidden_pick_visible = pick_array_element(idx_hidden, visible_dim_ids);
idx_hidden_pick_visible = idx_visible;
// initialize visible index
#pragma unroll
for(index_t i = 0; i < ndim_visible; ++i)
{
idx_hidden_pick_visible(i) = idx_visible[i];
}
// calculate hidden index // calculate hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) { static_for<ntransform - 1, -1, -1>{}([&tensor_desc, &idx_hidden](auto itran) {
const auto& tran = tensor_desc.GetTransforms().At(itran); const auto& tran = tensor_desc.GetTransforms().At(itran);
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
...@@ -451,13 +446,11 @@ make_dynamic_tensor_coordinate_step_v2(const TensorDesc&, const VisibleIndex& id ...@@ -451,13 +446,11 @@ make_dynamic_tensor_coordinate_step_v2(const TensorDesc&, const VisibleIndex& id
auto non_zero_diff_pick_visible = pick_array_element(non_zero_diff, visible_dim_ids); auto non_zero_diff_pick_visible = pick_array_element(non_zero_diff, visible_dim_ids);
#pragma unroll static_for<0, ndim_visible, 1>{}([&non_zero_diff_pick_visible, &idx_diff_visible](auto i) {
for(index_t i = 0; i < ndim_visible; ++i)
{
non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0);
} });
static_for<ntransform - 1, -1, -1>{}([&](auto itran) { static_for<ntransform - 1, -1, -1>{}([&do_transforms, &non_zero_diff](auto itran) {
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
...@@ -473,11 +466,10 @@ make_dynamic_tensor_coordinate_step_v2(const TensorDesc&, const VisibleIndex& id ...@@ -473,11 +466,10 @@ make_dynamic_tensor_coordinate_step_v2(const TensorDesc&, const VisibleIndex& id
do_transforms(itran) = idx_diff_up_has_non_zero; do_transforms(itran) = idx_diff_up_has_non_zero;
#pragma unroll static_for<0, dims_low.Size(), 1>{}(
for(index_t i = 0; i < dims_low.Size(); ++i) [&non_zero_diff_pick_low, &idx_diff_up_has_non_zero](auto i) {
{
non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero;
} });
}); });
return DynamicTensorCoordinateStep_v2<ntransform, ndim_visible>{idx_diff_visible, return DynamicTensorCoordinateStep_v2<ntransform, ndim_visible>{idx_diff_visible,
...@@ -498,16 +490,12 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten ...@@ -498,16 +490,12 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
// this is what needs to be calculated // this is what needs to be calculated
auto idx_diff_hidden = HiddenIndex{0}; auto idx_diff_hidden = HiddenIndex{0};
// initialize visible index diff
// idx_diff_hidden_pick_visible contains reference to idx_diff_hidden // idx_diff_hidden_pick_visible contains reference to idx_diff_hidden
auto idx_diff_hidden_pick_visible = auto idx_diff_hidden_pick_visible =
pick_array_element(idx_diff_hidden, TensorDesc::GetVisibleDimensionIds()); pick_array_element(idx_diff_hidden, TensorDesc::GetVisibleDimensionIds());
// initialize visible index diff idx_diff_hidden_pick_visible = coord_step.GetVisibleIndexDiff();
#pragma unroll
for(index_t i = 0; i < ndim_visible; ++i)
{
idx_diff_hidden_pick_visible(i) = coord_step.GetVisibleIndexDiff()[i];
}
// this is what needs to be updated // this is what needs to be updated
auto& idx_hidden = coord.GetHiddenIndex(); auto& idx_hidden = coord.GetHiddenIndex();
...@@ -518,7 +506,7 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten ...@@ -518,7 +506,7 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
idx_hidden_pick_visible += coord_step.GetIndexDiff(); idx_hidden_pick_visible += coord_step.GetIndexDiff();
// update rest of hidden index // update rest of hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) { static_for<ntransform - 1, -1, -1>{}([&tensor_desc, &idx_hidden, &idx_diff_hidden](auto itran) {
const auto& tran = tensor_desc.GetTransforms().At(itran); const auto& tran = tensor_desc.GetTransforms().At(itran);
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
...@@ -548,7 +536,7 @@ coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& te ...@@ -548,7 +536,7 @@ coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& te
const auto& idx_hidden = coord.GetHiddenIndex(); const auto& idx_hidden = coord.GetHiddenIndex();
static_for<ntransform - 1, -1, -1>{}([&](auto itran) { static_for<ntransform - 1, -1, -1>{}([&tensor_desc, &idx_hidden, &valid](auto itran) {
const auto tran = tensor_desc.GetTransforms().At(itran); const auto tran = tensor_desc.GetTransforms().At(itran);
// check validity, only if current transformation does not always have a valid mapping // check validity, only if current transformation does not always have a valid mapping
...@@ -573,12 +561,12 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& ...@@ -573,12 +561,12 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
bool is_visible_index_valid = true; bool is_visible_index_valid = true;
#pragma unroll static_for<0, TensorDesc::GetNumOfDimension(), 1>{}(
for(index_t i = 0; i < TensorDesc::GetNumOfDimension(); ++i) [&is_visible_index_valid, &idx_visible, &tensor_desc](auto i) {
{ is_visible_index_valid =
is_visible_index_valid = is_visible_index_valid && is_visible_index_valid &&
(idx_visible[i] >= 0 && idx_visible[i] < tensor_desc.GetLength(i)); (idx_visible[i] >= 0 && idx_visible[i] < tensor_desc.GetLength(i));
} });
// check other hidden index // check other hidden index
return is_visible_index_valid && return is_visible_index_valid &&
......
...@@ -5,9 +5,6 @@ ...@@ -5,9 +5,6 @@
namespace ck { namespace ck {
template <index_t... Is>
struct Sequence;
template <typename X, typename Y> template <typename X, typename Y>
struct is_same : public integral_constant<bool, false> struct is_same : public integral_constant<bool, false>
{ {
...@@ -18,26 +15,17 @@ struct is_same<X, X> : public integral_constant<bool, true> ...@@ -18,26 +15,17 @@ struct is_same<X, X> : public integral_constant<bool, true>
{ {
}; };
template <typename>
struct is_static : integral_constant<bool, false>
{
};
template <typename T, T X>
struct is_static<integral_constant<T, X>> : integral_constant<bool, true>
{
};
template <index_t... Is>
struct is_static<Sequence<Is...>> : integral_constant<bool, true>
{
};
template <typename T> template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type; using remove_reference_t = typename std::remove_reference<T>::type;
template <typename T> template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type; using remove_cv_t = typename std::remove_cv<T>::type;
// Casts the argument to an rvalue reference of its referred-to type.
// T is deduced from a forwarding reference, so both lvalues and rvalues are
// accepted; cv-qualifiers of the argument are preserved in the result type.
template <class T>
constexpr auto move(T&& t) noexcept -> std::remove_reference_t<T>&&
{
using rvalue_ref_t = std::remove_reference_t<T>&&;
return static_cast<rvalue_ref_t>(t);
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -192,7 +192,7 @@ void device_dummy_dynamic_transform_v2(InDesc, ...@@ -192,7 +192,7 @@ void device_dummy_dynamic_transform_v2(InDesc,
const auto in_gemmk_gemmn_coord_step = const auto in_gemmk_gemmn_coord_step =
make_dynamic_tensor_coordinate_step_v2(in_gemmk_gemmn_global_desc, MultiIndex<2>{0, 1}); make_dynamic_tensor_coordinate_step_v2(in_gemmk_gemmn_global_desc, MultiIndex<2>{0, 1});
for(index_t iter = 0; iter < 100; ++iter) for(index_t iter = 0; iter < 20; ++iter)
{ {
printf("iter %d\n", iter); printf("iter %d\n", iter);
print_array_v2("visible idx: ", in_gemmk_gemmn_coord.GetIndex()); print_array_v2("visible idx: ", in_gemmk_gemmn_coord.GetIndex());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment