Commit 4d13badd authored by Chao Liu
Browse files

clean up

parent b23e7f8e
......@@ -156,7 +156,7 @@ struct DynamicTensorDescriptor_v2
hidden_lengths(0) = element_space_size;
// lengths for all other hidden dimensions
static_for<0, ntransform_, 1>{}([&](auto itran) {
static_for<0, ntransform_, 1>{}([&transforms, &hidden_lengths](auto itran) {
const auto& tran = transforms.At(itran);
constexpr auto up_dim_ids = UpperDimensionIdss{}.At(itran);
......@@ -408,17 +408,12 @@ make_dynamic_tensor_coordinate_v2(const TensorDesc& tensor_desc, const VisibleIn
MultiIndex<ndim_hidden> idx_hidden;
// initialize visible index
auto idx_hidden_pick_visible = pick_array_element(idx_hidden, visible_dim_ids);
// initialize visible index
#pragma unroll
for(index_t i = 0; i < ndim_visible; ++i)
{
idx_hidden_pick_visible(i) = idx_visible[i];
}
idx_hidden_pick_visible = idx_visible;
// calculate hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
static_for<ntransform - 1, -1, -1>{}([&tensor_desc, &idx_hidden](auto itran) {
const auto& tran = tensor_desc.GetTransforms().At(itran);
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
......@@ -451,13 +446,11 @@ make_dynamic_tensor_coordinate_step_v2(const TensorDesc&, const VisibleIndex& id
auto non_zero_diff_pick_visible = pick_array_element(non_zero_diff, visible_dim_ids);
#pragma unroll
for(index_t i = 0; i < ndim_visible; ++i)
{
static_for<0, ndim_visible, 1>{}([&non_zero_diff_pick_visible, &idx_diff_visible](auto i) {
non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0);
}
});
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
static_for<ntransform - 1, -1, -1>{}([&do_transforms, &non_zero_diff](auto itran) {
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
......@@ -473,11 +466,10 @@ make_dynamic_tensor_coordinate_step_v2(const TensorDesc&, const VisibleIndex& id
do_transforms(itran) = idx_diff_up_has_non_zero;
#pragma unroll
for(index_t i = 0; i < dims_low.Size(); ++i)
{
non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero;
}
static_for<0, dims_low.Size(), 1>{}(
[&non_zero_diff_pick_low, &idx_diff_up_has_non_zero](auto i) {
non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero;
});
});
return DynamicTensorCoordinateStep_v2<ntransform, ndim_visible>{idx_diff_visible,
......@@ -498,16 +490,12 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
// this is what needs to be calculated
auto idx_diff_hidden = HiddenIndex{0};
// idx_diff_hidden_pick_visible contains reference to idx_diff_hidden
// initialize visible index diff
// idx_diff_hidden_pick_visible contains reference to idx_diff_hidden
auto idx_diff_hidden_pick_visible =
pick_array_element(idx_diff_hidden, TensorDesc::GetVisibleDimensionIds());
// initialize visible index diff
#pragma unroll
for(index_t i = 0; i < ndim_visible; ++i)
{
idx_diff_hidden_pick_visible(i) = coord_step.GetVisibleIndexDiff()[i];
}
idx_diff_hidden_pick_visible = coord_step.GetVisibleIndexDiff();
// this is what needs to be updated
auto& idx_hidden = coord.GetHiddenIndex();
......@@ -518,7 +506,7 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
idx_hidden_pick_visible += coord_step.GetIndexDiff();
// update rest of hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
static_for<ntransform - 1, -1, -1>{}([&tensor_desc, &idx_hidden, &idx_diff_hidden](auto itran) {
const auto& tran = tensor_desc.GetTransforms().At(itran);
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
......@@ -548,7 +536,7 @@ coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& te
const auto& idx_hidden = coord.GetHiddenIndex();
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
static_for<ntransform - 1, -1, -1>{}([&tensor_desc, &idx_hidden, &valid](auto itran) {
const auto tran = tensor_desc.GetTransforms().At(itran);
// check validity, only if current transformation does not always have a valid mapping
......@@ -573,12 +561,12 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
bool is_visible_index_valid = true;
#pragma unroll
for(index_t i = 0; i < TensorDesc::GetNumOfDimension(); ++i)
{
is_visible_index_valid = is_visible_index_valid &&
(idx_visible[i] >= 0 && idx_visible[i] < tensor_desc.GetLength(i));
}
static_for<0, TensorDesc::GetNumOfDimension(), 1>{}(
[&is_visible_index_valid, &idx_visible, &tensor_desc](auto i) {
is_visible_index_valid =
is_visible_index_valid &&
(idx_visible[i] >= 0 && idx_visible[i] < tensor_desc.GetLength(i));
});
// check other hidden index
return is_visible_index_valid &&
......
......@@ -5,9 +5,6 @@
namespace ck {
template <index_t... Is>
struct Sequence;
template <typename X, typename Y>
struct is_same : public integral_constant<bool, false>
{
......@@ -18,26 +15,17 @@ struct is_same<X, X> : public integral_constant<bool, true>
{
};
// Compile-time trait: is_static<T>::value tells whether T is one of the
// recognised compile-time ("static") types.
// Primary template: any type not matched by a specialization below is not static.
template <typename T>
struct is_static : public integral_constant<bool, false>
{
};

// Specialization: every integral_constant<T, X> is static.
template <typename T, T X>
struct is_static<integral_constant<T, X>> : public integral_constant<bool, true>
{
};

// Specialization: every Sequence<Is...> is static.
template <index_t... Is>
struct is_static<Sequence<Is...>> : public integral_constant<bool, true>
{
};
// Shorthand for the type obtained by stripping any reference qualifier from T.
template <typename T>
using remove_reference_t = std::remove_reference_t<T>;

// Shorthand for the type obtained by stripping top-level const/volatile from T.
template <typename T>
using remove_cv_t = std::remove_cv_t<T>;
// Casts `t` to an rvalue reference of its referred-to type (reference
// qualifiers on T are stripped first). Only a cast is performed; nothing is
// copied or modified here.
template <class T>
constexpr std::remove_reference_t<T>&& move(T&& t) noexcept
{
    using rvalue_ref_t = std::remove_reference_t<T>&&;
    return static_cast<rvalue_ref_t>(t);
}
} // namespace ck
#endif
......@@ -192,7 +192,7 @@ void device_dummy_dynamic_transform_v2(InDesc,
const auto in_gemmk_gemmn_coord_step =
make_dynamic_tensor_coordinate_step_v2(in_gemmk_gemmn_global_desc, MultiIndex<2>{0, 1});
for(index_t iter = 0; iter < 100; ++iter)
for(index_t iter = 0; iter < 20; ++iter)
{
printf("iter %d\n", iter);
print_array_v2("visible idx: ", in_gemmk_gemmn_coord.GetIndex());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment