"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6f2b310a1765177f0d7e9b4b6c8bcfe7e5d3a8a8"
Commit 65a00640 authored by Chao Liu

fix bug in tensor adaptor

parent fc148cef
@@ -45,6 +45,7 @@ __host__ __device__ constexpr auto make_cluster_descriptor(
     return ClusterDescriptor<Lengths, decltype(order)>{};
 }
 
+#if 1
 template <typename Lengths,
           typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
 __host__ __device__ constexpr auto make_cluster_descriptor_v2(

@@ -64,9 +65,10 @@ __host__ __device__ constexpr auto make_cluster_descriptor_v2(
     constexpr auto up_dim_new_top_ids = Sequence<0>{};
 
-    return make_simple_tensor_adaptor(
+    return make_single_stage_tensor_adaptor(
         make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
 }
+#endif
 
 } // namespace ck
 #endif
@@ -1282,7 +1282,7 @@ struct DynamicFreeze
     __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                            const UpIdx& idx_up) const
     {
-        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
+        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0,
                       "wrong! inconsistent # of dimension");
 
         idx_low = low_idx_;

@@ -1299,7 +1299,7 @@ struct DynamicFreeze
                                               const UpIdx& idx_up_new,
                                               Number<Hack>)
     {
-        idx_diff_low(Number<0>{}) = index_t{Number<0>{}};
+        idx_diff_low(Number<0>{}) = 0;
     }
 
     __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
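(Note on the two DynamicFreeze fixes above, inferred from the diff rather than stated in the commit message: a freeze transform pins its single lower dimension to the stored low_idx_ and exposes no upper dimension at all, so the consistency check should require UpIdx::Size() == 0, not 1; and UpdateLowerIndex now assigns a plain 0 to the index difference instead of the roundabout index_t{Number<0>{}}.)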
@@ -1328,5 +1328,90 @@ struct DynamicFreeze
     }
 };
 
+template <typename VectorSize, typename UpLength>
+struct DynamicVectorize
+{
+    using LowerIndex = MultiIndex<1>;
+    using UpperIndex = MultiIndex<1>;
+
+    using UpLengths = decltype(make_tuple(UpLength{}));
+
+    UpLengths up_lengths_;
+    VectorSize vector_size_;
+
+    __host__ __device__ constexpr DynamicVectorize() = default;
+
+    __host__ __device__ constexpr DynamicVectorize(const VectorSize& vector_size,
+                                                   const UpLength& up_length)
+        : vector_size_{vector_size}, up_lengths_{make_tuple(up_length)}
+    {
+    }
+
+    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
+
+    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
+
+    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
+
+    template <typename LowIdx, typename UpIdx>
+    __host__ __device__ void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const
+    {
+        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
+                      "wrong! inconsistent # of dimension");
+
+        idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}];
+    }
+
+    template <typename LowIdxDiff,
+              typename UpIdxDiff,
+              typename LowIdx,
+              typename UpIdx,
+              index_t Hack>
+    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
+                                              const UpIdxDiff& idx_diff_up,
+                                              LowIdx& idx_low,
+                                              const UpIdx& idx_up_new,
+                                              Number<Hack>) const
+    {
+        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
+                          UpIdx::Size() == 1,
+                      "wrong! inconsistent # of dimension");
+
+        constexpr auto I0 = Number<0>{};
+
+        idx_diff_low(I0) = vector_size_ * idx_diff_up[I0];
+
+        idx_low += idx_diff_low;
+    }
+
+    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
+
+    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
+    {
+        return true;
+    }
+
+    template <typename UpIdx>
+    __host__ __device__ static constexpr bool
+    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
+    {
+        return true;
+    }
+
+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        return is_known_at_compile_time<UpLengths>::value;
+    }
+
+    __host__ __device__ void Print() const
+    {
+        printf("{");
+        printf("DynamicVectorize, ");
+        printf("up_lengths_");
+        print_multi_index(up_lengths_);
+        printf("}");
+    }
+};
+
 } // namespace ck
 #endif
@@ -74,5 +74,12 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i
     return DynamicFreeze<LowerIndex>{low_idx};
 }
 
+template <typename VectorSize, typename UpLength>
+__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
+                                                            const UpLength& up_length)
+{
+    return DynamicVectorize<VectorSize, UpLength>{vector_size, up_length};
+}
+
 } // namespace ck
 #endif
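As an illustration of what the new transform computes (a minimal standalone sketch with made-up numbers, using plain ints in place of ck's MultiIndex and Number types; not part of the commit):

    #include <cstdio>

    int main()
    {
        const int vector_size = 4; // hypothetical: 4 scalars per vector
        const int up_length   = 3; // hypothetical upper-dimension length

        // CalculateLowerIndex: idx_low = vector_size_ * idx_up
        for(int idx_up = 0; idx_up < up_length; ++idx_up)
        {
            std::printf("idx_up = %d -> idx_low = %d\n", idx_up, vector_size * idx_up);
        }

        // UpdateLowerIndex: an upper-index step idx_diff_up becomes a
        // lower-index step of vector_size_ * idx_diff_up
        const int idx_diff_up  = 2;
        const int idx_diff_low = vector_size * idx_diff_up;
        std::printf("idx_diff_up = %d -> idx_diff_low = %d\n", idx_diff_up, idx_diff_low);

        return 0;
    }

In other words, make_vectorize_transform describes a view in which each upper index addresses one vector_size-wide chunk of the lower dimension.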
@@ -235,15 +235,31 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
             constexpr index_t ndim_low =
                 TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension();
 
+            // get the min of all lower dimensions, but not bottom dimensions (because their ids
+            // will be matched with top ids from adaptor0)
             static_for<0, ndim_low, 1>{}([&](auto idim_low) {
-                adaptor1_min_hidden_id =
-                    math::min(adaptor1_min_hidden_id,
-                              TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value);
+                constexpr index_t low_dim_hidden_id =
+                    TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value;
+
+                bool is_bottom_dim = false;
+
+                static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](auto i) {
+                    if constexpr(low_dim_hidden_id ==
+                                 TensorAdaptor1::GetBottomDimensionHiddenIds()[i])
+                    {
+                        is_bottom_dim = true;
+                    }
+                });
+
+                if(!is_bottom_dim)
+                {
+                    adaptor1_min_hidden_id = math::min(adaptor1_min_hidden_id, low_dim_hidden_id);
+                }
             });
 
             constexpr index_t ndim_up =
                 TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension();
 
+            // get the min of all upper dimensions
             static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                 adaptor1_min_hidden_id =
                     math::min(adaptor1_min_hidden_id,
@@ -255,7 +271,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
         }();
 
         constexpr index_t adaptor1_hidden_id_shift =
-            adaptor1_min_hidden_id - adaptor0_max_hidden_id + 1;
+            adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;
 
         constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension();
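To see the sign fix with concrete, made-up numbers: if adaptor0's largest hidden id is 5 and adaptor1's smallest non-bottom hidden id is 2, the corrected shift is 5 + 1 - 2 = 4, and adding it (as the hunks below now do) maps adaptor1's hidden ids {2, 3, ...} to {6, 7, ...}, just past adaptor0's range, so every hidden id stays unique. The old expression gave 2 - 5 + 1 = -2, and subtracting it produced {4, 5, ...}, which collides with adaptor0's id 5.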
@@ -276,7 +292,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
             // shift hidden id so every dim id is unique
             static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
-                low_dim_hidden_ids_1_mod(idim_low_1) -= adaptor1_hidden_id_shift;
+                low_dim_hidden_ids_1_mod(idim_low_1) += adaptor1_hidden_id_shift;
             });
 
             // match hidden id
@@ -322,7 +338,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
             // shift hidden id
             static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
-                up_dim_hidden_ids_1_mod(idim_up_1) -= adaptor1_hidden_id_shift;
+                up_dim_hidden_ids_1_mod(idim_up_1) += adaptor1_hidden_id_shift;
             });
 
             return up_dim_hidden_ids_1_mod;
@@ -344,23 +360,23 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
     // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
     constexpr auto top_dim_hidden_ids =
-        TensorAdaptor1::GetTopDimensionHiddenIds() - Number<adaptor1_hidden_id_shift>{};
+        TensorAdaptor1::GetTopDimensionHiddenIds() + Number<adaptor1_hidden_id_shift>{};
 
     // put everything together
-    return TensorAdaptor<decltype(all_transforms),
-                         decltype(all_low_dim_hidden_idss),
-                         decltype(all_up_dim_hidden_idss),
-                         decltype(bottom_dim_hidden_ids),
-                         decltype(top_dim_hidden_ids)>{all_transforms};
+    return TensorAdaptor<remove_cv_t<decltype(all_transforms)>,
+                         remove_cv_t<decltype(all_low_dim_hidden_idss)>,
+                         remove_cv_t<decltype(all_up_dim_hidden_idss)>,
+                         remove_cv_t<decltype(bottom_dim_hidden_ids)>,
+                         remove_cv_t<decltype(top_dim_hidden_ids)>>{all_transforms};
 }
 
 // Transforms: Tuple<transforms...>
 // LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
 // UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
 template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
-__host__ __device__ constexpr auto make_simple_tensor_adaptor(const Transforms& transforms,
-                                                              LowerDimensionOldTopIdss,
-                                                              UpperDimensionNewTopIdss)
+__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
+                                                                    LowerDimensionOldTopIdss,
+                                                                    UpperDimensionNewTopIdss)
 {
     constexpr index_t ntransform = Transforms::Size();
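(The remove_cv_t wrappers here and in the hunk below are presumably needed because all_transforms and the hidden-id sequences are constexpr locals, so decltype on them yields const-qualified types; stripping the qualifiers keeps the TensorAdaptor template arguments as plain, mutually consistent types.)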
@@ -400,11 +416,19 @@ __host__ __device__ constexpr auto make_simple_tensor_adaptor(const Transforms&
     constexpr auto top_dim_hidden_ids =
         typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number<ndim_old_top>{};
 
-    return TensorAdaptor<Transforms,
-                         decltype(low_dim_hidden_idss),
-                         decltype(up_dim_hidden_idss),
-                         decltype(bottom_dim_hidden_ids),
-                         decltype(top_dim_hidden_ids)>{transforms};
+    return TensorAdaptor<remove_cv_t<Transforms>,
+                         remove_cv_t<decltype(low_dim_hidden_idss)>,
+                         remove_cv_t<decltype(up_dim_hidden_idss)>,
+                         remove_cv_t<decltype(bottom_dim_hidden_ids)>,
+                         remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
 }
+
+template <typename X,
+          typename... Xs,
+          typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
+__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
+{
+    return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
+}
 
 } // namespace ck
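The new variadic overload reduces an N-adaptor chain to nested calls of the existing two-adaptor overload, right to left. A minimal standalone sketch of the same recursion pattern, with ints standing in for adaptors and a hypothetical chain2 in place of the binary chain_tensor_adaptors (not part of the commit):

    #include <type_traits>

    // stand-in for the binary overload
    constexpr int chain2(int a, int b) { return a * 10 + b; }

    // variadic overload, enabled only for 3+ arguments so it never
    // competes with the binary overload
    template <typename X,
              typename... Xs,
              typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
    constexpr int chain2(const X& x, const Xs&... xs)
    {
        return chain2(x, chain2(xs...));
    }

    // chain2(1, 2, 3, 4) expands to chain2(1, chain2(2, chain2(3, 4)))
    static_assert(chain2(1, 2, 3, 4) == chain2(1, chain2(2, chain2(3, 4))), "right fold");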