Commit 65a00640 authored by Chao Liu

fix bug in tensor adaptor

parent fc148cef
@@ -45,6 +45,7 @@ __host__ __device__ constexpr auto make_cluster_descriptor(
return ClusterDescriptor<Lengths, decltype(order)>{};
}
#if 1
template <typename Lengths,
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor_v2(
@@ -64,9 +65,10 @@ __host__ __device__ constexpr auto make_cluster_descriptor_v2(
constexpr auto up_dim_new_top_ids = Sequence<0>{};
return make_simple_tensor_adaptor(
return make_single_stage_tensor_adaptor(
make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
}
#endif
} // namespace ck
#endif
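A usage sketch (hypothetical lengths; assumes the vectorize transform introduced later in this commit is in scope): a single-stage adaptor takes one tuple of transforms, where the lower ids refer to the old top dimensions and the upper ids define the new top dimensions.

    // hypothetical sketch: view a length-32 dimension as 8 vectors of 4 elements
    constexpr auto adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_vectorize_transform(Number<4>{}, Number<8>{})),
        make_tuple(Sequence<0>{}),  // lower: old top dim 0
        make_tuple(Sequence<0>{})); // upper: new top dim 0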
@@ -1282,7 +1282,7 @@ struct DynamicFreeze
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0,
"wrong! inconsistent # of dimension");
idx_low = low_idx_;
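// note: Freeze removes the frozen dimension from the upper index, so the upper
// index tuple is empty (UpIdx::Size() == 0) and the lower index is always low_idx_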
@@ -1299,7 +1299,7 @@ struct DynamicFreeze
const UpIdx& idx_up_new,
Number<Hack>)
{
idx_diff_low(Number<0>{}) = index_t{Number<0>{}};
idx_diff_low(Number<0>{}) = 0;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
@@ -1328,5 +1328,90 @@ struct DynamicFreeze
}
};
template <typename VectorSize, typename UpLength>
struct DynamicVectorize
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
using UpLengths = decltype(make_tuple(UpLength{}));
UpLengths up_lengths_;
VectorSize vector_size_;
__host__ __device__ constexpr DynamicVectorize() = default;
__host__ __device__ constexpr DynamicVectorize(const VectorSize& vector_size,
const UpLength& up_length)
: up_lengths_{make_tuple(up_length)}, vector_size_{vector_size}
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const
{
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}];
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ constexpr void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& idx_up_new,
Number<Hack>) const
{
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = Number<0>{};
idx_diff_low(I0) = vector_size_ * idx_diff_up[I0];
idx_low += idx_diff_low;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value;
}
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicVectorize, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("}");
}
};
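Because the transform is linear, moving the upper index by idx_diff_up moves the lower index by vector_size * idx_diff_up. A minimal sketch of that arithmetic with hypothetical values:

    // sketch: vector_size = 4; stepping the upper index by 2
    // moves the lower index by 8 scalars
    constexpr int vector_size  = 4;
    constexpr int idx_diff_up  = 2;
    constexpr int idx_diff_low = vector_size * idx_diff_up; // 8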
} // namespace ck
#endif
@@ -74,5 +74,12 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i
return DynamicFreeze<LowerIndex>{low_idx};
}
template <typename VectorSize, typename UpLength>
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
const UpLength& up_length)
{
return DynamicVectorize<VectorSize, UpLength>{vector_size, up_length};
}
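For example (hypothetical values), viewing 32 scalars as 8 vectors of 4 elements each:

    // hypothetical usage; lower index = 4 * upper index
    const auto vec = make_vectorize_transform(4, 8);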
} // namespace ck
#endif
@@ -235,15 +235,31 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
constexpr index_t ndim_low =
TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension();
// get the min of all lower dimension ids, but not bottom dimension ids (because they
// will be matched with top ids from adaptor0)
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
adaptor1_min_hidden_id =
math::min(adaptor1_min_hidden_id,
TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value);
constexpr index_t low_dim_hidden_id =
TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value;
bool is_bottom_dim = false;
static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](auto i) {
if constexpr(low_dim_hidden_id ==
TensorAdaptor1::GetBottomDimensionHiddenIds()[i])
{
is_bottom_dim = true;
}
});
if(!is_bottom_dim)
{
adaptor1_min_hidden_id = math::min(adaptor1_min_hidden_id, low_dim_hidden_id);
}
});
constexpr index_t ndim_up =
TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension();
// get the min of all upper dimension ids
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
adaptor1_min_hidden_id =
math::min(adaptor1_min_hidden_id,
@@ -255,7 +271,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
}();
constexpr index_t adaptor1_hidden_id_shift =
adaptor1_min_hidden_id - adaptor0_max_hidden_id + 1;
adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;
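// e.g. with hypothetical ids adaptor0_max_hidden_id = 5 and adaptor1_min_hidden_id = 1,
// shift = 5 + 1 - 1 = 5: adding it below moves adaptor1's smallest non-bottom hidden
// id to 6, one past adaptor0's ids (the old subtractive formula mapped it to 4,
// colliding with adaptor0's range)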
constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension();
@@ -276,7 +292,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
// shift hidden id so every dim id is unique
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
low_dim_hidden_ids_1_mod(idim_low_1) -= adaptor1_hidden_id_shift;
low_dim_hidden_ids_1_mod(idim_low_1) += adaptor1_hidden_id_shift;
});
// match hidden id
@@ -322,7 +338,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
// shift hidden id
static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
up_dim_hidden_ids_1_mod(idim_up_1) -= adaptor1_hidden_id_shift;
up_dim_hidden_ids_1_mod(idim_up_1) += adaptor1_hidden_id_shift;
});
return up_dim_hidden_ids_1_mod;
@@ -344,23 +360,23 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
// top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
constexpr auto top_dim_hidden_ids =
TensorAdaptor1::GetTopDimensionHiddenIds() - Number<adaptor1_hidden_id_shift>{};
TensorAdaptor1::GetTopDimensionHiddenIds() + Number<adaptor1_hidden_id_shift>{};
// put everything together
return TensorAdaptor<decltype(all_transforms),
decltype(all_low_dim_hidden_idss),
decltype(all_up_dim_hidden_idss),
decltype(bottom_dim_hidden_ids),
decltype(top_dim_hidden_ids)>{all_transforms};
return TensorAdaptor<remove_cv_t<decltype(all_transforms)>,
remove_cv_t<decltype(all_low_dim_hidden_idss)>,
remove_cv_t<decltype(all_up_dim_hidden_idss)>,
remove_cv_t<decltype(bottom_dim_hidden_ids)>,
remove_cv_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
__host__ __device__ constexpr auto make_simple_tensor_adaptor(const Transforms& transforms,
LowerDimensionOldTopIdss,
UpperDimensionNewTopIdss)
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
LowerDimensionOldTopIdss,
UpperDimensionNewTopIdss)
{
constexpr index_t ntransform = Transforms::Size();
@@ -400,11 +416,19 @@ __host__ __device__ constexpr auto make_simple_tensor_adaptor(const Transforms&
constexpr auto top_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number<ndim_old_top>{};
return TensorAdaptor<Transforms,
decltype(low_dim_hidden_idss),
decltype(up_dim_hidden_idss),
decltype(bottom_dim_hidden_ids),
decltype(top_dim_hidden_ids)>{transforms};
return TensorAdaptor<remove_cv_t<Transforms>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(bottom_dim_hidden_ids)>,
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
}
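// note: decltype on the constexpr locals above yields const-qualified types, so
// remove_cv_t keeps the TensorAdaptor template arguments cv-unqualified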
template <typename X,
typename... Xs,
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
}
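The variadic overload folds from the right: chain_tensor_adaptors(a, b, c) expands to chain_tensor_adaptors(a, chain_tensor_adaptors(b, c)), bottoming out at the two-adaptor overload above. A hypothetical call site:

    // hypothetical: compose three single-stage adaptors into one
    const auto chained = chain_tensor_adaptors(adaptor_a, adaptor_b, adaptor_c);
    // == chain_tensor_adaptors(adaptor_a, chain_tensor_adaptors(adaptor_b, adaptor_c))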
} // namespace ck