"...git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "d99469a5bd82eed5b1d4605d3ef04b945d86b1a2"
Commit bd44e639 authored by Chao Liu

adding dimension transformation

parent cb6475c7
#ifndef CK_DIMENSION_HPP
#define CK_DIMENSION_HPP
#include "common_header.hpp"
namespace ck {
template <index_t Length>
struct Dimension
{
__host__ __device__ static constexpr auto GetLength() { return Number<Length>{}; }
};
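// NativeDimension: a dimension backed directly by memory; element id maps to linear offset id * Stride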
template <index_t Length, index_t Stride>
struct NativeDimension : Dimension<Length>
{
__host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; }
__host__ __device__ static constexpr index_t GetOffset(index_t id) { return id * Stride; }
__host__ __device__ static constexpr index_t GetOffsetDiff(index_t id_diff)
{
return id_diff * Stride;
}
};
} // namespace ck
#endif
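
A minimal usage sketch of dimension.hpp (assuming common_header.hpp provides index_t and Number<> as used above):

// compile-time checks of the stride-based offset computation
#include "dimension.hpp"

using Dim = ck::NativeDimension<8, 4>; // length 8, stride 4

static_assert(Dim::GetOffset(3) == 12, "offset = id * Stride");
static_assert(Dim::GetOffsetDiff(2) == 8, "offset differences scale by the stride");
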
#ifndef CK_DIMENSION_TRANSFORM_HPP
#define CK_DIMENSION_TRANSFORM_HPP
#include "common_header.hpp"
namespace ck {
template <index_t N>
using MultiIndex = Array<index_t, N>;
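// PassThrough: identity transform; upper dimensions, lengths and ids are the same as the lower ones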
// LowLengths: Sequence<...>
template <class LowLengths>
struct PassThrough
{
static constexpr index_t nDim = LowLengths::GetSize();
using LowerId = MultiIndex<nDim>;
using UpperId = LowerId;
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetUpperNumOfDimension()
{
return GetLowerNumOfDimension();
}
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return GetLowerLengths(); }
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up) { return id_up; }
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
{
return id_up_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
};
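// Pad: extends the lower lengths by LeftPads and RightPads; an upper id maps back to id_up - LeftPads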
// LowLengths: Sequence<...>
template <class LowLengths, class LeftPads, class RightPads>
struct Pad
{
static constexpr index_t nDim = LowLengths::GetSize();
using LowerId = MultiIndex<nDim>;
using UpperId = LowerId;
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetUpperNumOfDimension()
{
return GetLowerNumOfDimension();
}
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
__host__ __device__ static constexpr auto GetUpperLengths()
{
return GetLowerLengths() + LeftPads + RightPads;
}
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up) { return id_up - LeftPads; }
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
{
return id_up_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
};
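// Merge: folds all lower dimensions into a single upper dimension whose length is the
// product of the lower lengths; mapping an upper id back to lower ids is not linear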
// LowLengths: Sequence<...>
template <class LowLengths>
struct Merge
{
static constexpr index_t nDimLow = LowLengths::GetSize();
static constexpr index_t nDimUp = 1;
using LowerId = MultiIndex<nDimLow>;
using UpperId = MultiIndex<nDimUp>;
__host__ __device__ static constexpr auto GetUpperNumOfDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
__host__ __device__ static constexpr auto GetUpperLengths()
{
return Sequence<accumulate_on_sequence(
GetLowerLengths(), math::multiplies<index_t>{}, Number<1>{})>{};
}
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
{
LowerId id_low;
// not implemented: should unfold id_up[0] into id_low by repeated division/modulo
// with the lower lengths
return id_low;
}
// id_low_diff depends on id_low_old, so id_low_old needs to be up-to-date
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff, LowerId id_low_old)
{
LowerId id_low_diff;
// not implemented
return id_low_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
};
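// Unmerge: splits a single lower dimension of length LowLength into multiple upper
// dimensions whose lengths multiply up to LowLength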
// UpLengths: Sequence<...>
template <index_t LowLength, class UpLengths>
struct Unmerge
{
static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpLengths::GetSize();
using LowerId = MultiIndex<nDimLow>;
using UpperId = MultiIndex<nDimUp>;
__host__ __device__ constexpr Unmerge()
{
static_assert(LowLength == accumulate_on_sequence(
UpLengths{}, math::multiplies<index_t>{}, Number<1>{}),
"wrong! UpLengths need to multiply up to LowLength");
}
__host__ __device__ static constexpr auto GetUpperNumOfDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
{
constexpr auto scans = typename sequence_reverse_inclusive_scan<UpLengths,
math::multiplies<index_t>,
1>::type{};
LowerId id_low{0};
static_for<0, nDimUp, 1>{}([&](auto idim) { id_low[0] += id_up[idim] * scans[idim]; });
return id_low;
}
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
{
return GetLowerId(id_up_diff);
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
};
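// Embed: affine map from multiple upper dimensions into a single lower dimension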
// UpLengths: Sequence<...>
// Coefficients: Sequence<...>
// id_low = coefficients[0, ...nDimUp-1] * id_up[0, ...nDimUp-1] + coefficients[nDimUp]
template <index_t LowLength, class UpLengths, class Coefficients>
struct Embed
{
static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpLengths::GetSize();
using LowerId = MultiIndex<nDimLow>;
using UpperId = MultiIndex<nDimUp>;
static constexpr auto mCoefficients = Coefficients{};
__host__ __device__ constexpr Embed()
{
static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
"wrong! # of dimensions not consistent");
constexpr index_t low_id_max =
Coefficients{}.Back() + accumulate_on_sequence(UpLengths{} * Coefficients::PopBack(),
math::plus<index_t>{},
Number<0>{});
static_assert(low_id_max < LowLength, "wrong! lower-id will go out of range");
}
__host__ __device__ static constexpr auto GetUpperNumOfDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
{
LowerId id_low{mCoefficients[nDimUp]};
static_for<0, nDimUp, 1>{}(
[&](auto idim) { id_low[0] += id_up[idim] * mCoefficients[idim]; });
return id_low;
}
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
{
LowerId id_low_diff{0};
static_for<0, nDimUp, 1>{}(
[&](auto idim) { id_low_diff[0] += id_up_diff[idim] * mCoefficients[idim]; });
return id_low_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
};
} // namespace ck
#endif
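
A short sketch of how the transforms might be instantiated (file name inferred from the header guard; lengths and pads are illustrative):

#include "dimension_transform.hpp"

using namespace ck;

// pad a 4x6 region by 1 on each side of each dimension, giving upper lengths 6x8
using PadXY = Pad<Sequence<4, 6>, Sequence<1, 1>, Sequence<1, 1>>;
static_assert(PadXY::IsLinearTransform(), "padding only shifts ids, so it is linear");

// fold two lower dimensions of lengths 3 and 4 into one upper dimension of length 12
using MergeKC = Merge<Sequence<3, 4>>;
static_assert(!MergeKC::IsLinearTransform(), "carries make the inverse mapping non-linear");
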
#ifndef CK_TENSOR_DESCRIPTOR_HPP
#define CK_TENSOR_DESCRIPTOR_HPP
#include "common_header.hpp"
#include "dimension.hpp"
namespace ck {
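// NativeTensorDescriptor: a tensor described by per-dimension lengths and strides;
// the offset of a multi-index is the dot product of the index with the strides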
template <class Lengths, class Strides>
struct NativeTensorDescriptor
{
using type = NativeTensorDescriptor;
static constexpr index_t nDim = Lengths::GetSize();
using Id = MultiIndex<nDim>;
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
__host__ __device__ static constexpr auto GetStrides() { return Strides{}; }
__host__ __device__ static constexpr auto GetLength(index_t IDim) { return Lengths{}[IDim]; }
__host__ __device__ static constexpr auto GetStride(index_t IDim) { return Strides{}[IDim]; }
__host__ __device__ static constexpr index_t GetOffset(Id id)
{
// offset is the dot product of the multi-index with the strides
index_t offset = 0;
static_for<0, nDim, 1>{}([&](auto idim) { offset += id[idim] * Strides{}[idim]; });
return offset;
}
};
// LowTensorDescriptor: the tensor descriptor being transformed
// Transforms: tuple<DimensionTransforms...>
// LowDimensionMasks: tuple<Sequence<...>>, the lower dimensions each transform reads
// UpDimensionMasks: tuple<Sequence<...>>, the upper dimensions each transform produces
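// Intended composition (as suggested by the TODOs below): the upper multi-index is split
// across the transforms according to UpDimensionMasks, each transform maps its piece to
// lower indices, and those are gathered via LowDimensionMasks and handed to LowTensorDescriptor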
template <class LowTensorDescriptor,
class Transforms,
class LowDimensionMasks,
class UpDimensionMasks>
struct TransformedTensorDescriptor
{
using type = TransformedTensorDescriptor;
static constexpr index_t nDimUp = xxxx;
static constexpr index_t nDimLow = xxx;
static constexpr index_t nTransform = Transforms::GetSize();
using UpperId = MultiIndex<nDimUp>;
using LowerId = MultiIndex<nDimLow>;
__host__ __device__ constexpr TransformedTensorDescriptor()
{
static_assert(nTransform == Transforms::GetSize() &&
nTransform == LowDimensionMasks::GetSize() &&
nTransform == UpDimensionMasks::GetSize(),
"wrong! # of transformations not the same");
// TODO: sanity check: LowDimensionMasks should cover all low-dimensions,
// UpDimensionMasks should cover all up-dimensions
// TODO: sanity check: while an up-dimension could be associated with multiple
// transformations, a low-dimension should be associated with only one
}
__host__ __device__ static constexpr auto GetNumOfDimension()
{
// not implemented
}
__host__ __device__ static constexpr auto GetLengths()
{
// not implemented
}
__host__ __device__ static constexpr auto GetLowerTensorDescriptor()
{
return LowTensorDescriptor{};
}
__host__ __device__ static constexpr LowerId GetLowerId(UpperId id_up)
{
// not implemented
}
__host__ __device__ static constexpr index_t GetOffset(UpperId id_up)
{
return GetLowerTensorDescriptor().GetOffset(GetLowerId(id_up));
}
__host__ __device__ static constexpr auto AreUpperId2OffsetLinear()
{
// not implemented
}
__host__ __device__ static constexpr auto GetIndependentDimensionGroups()
{
// not implemented
}
};
} // namespace ck
#endif
@@ -4,6 +4,7 @@
#include "config.hpp"
#include "utility.hpp"
#include "integral_constant.hpp"
#include "tuple.hpp"
#include "math.hpp"
#include "vector_type.hpp"
#include "Sequence.hpp"
#ifndef CK_TUPLE_HPP
#define CK_TUPLE_HPP
#include "integral_constant.hpp"
namespace ck {
template <class... Ts>
struct tuple : public std::tuple<Ts...>
{
using type = tuple;
__host__ __device__ static constexpr index_t GetSize() { return sizeof...(Ts); }
template <index_t I>
__host__ __device__ constexpr auto Get(Number<I>) const
{
return std::get<I>(*this);
}
template <index_t I>
__host__ __device__ constexpr auto operator[](Number<I>) const
{
return Get(Number<I>{});
}
};
// merge tuple
template <class... Tuples>
__host__ __device__ constexpr auto merge_tuple(Tuples&&... xs)
{
return std::tuple_cat(xs...);
};
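// The generator below needs a tuple_merge trait that concatenates two tuple types.
// A minimal sketch, assuming it mirrors sequence_merge for Sequence:
template <class X, class Y>
struct tuple_merge;

template <class... Xs, class... Ys>
struct tuple_merge<tuple<Xs...>, tuple<Ys...>>
{
using type = tuple<Xs..., Ys...>;
};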
// generate tuple by applying F to each index in [0, NSize)
template <index_t IBegin, index_t NRemain, class F>
struct tuple_gen_impl
{
static constexpr index_t NRemainLeft = NRemain / 2;
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
static constexpr index_t IMiddle = IBegin + NRemainLeft;
using type =
typename tuple_merge<typename tuple_gen_impl<IBegin, NRemainLeft, F>::type,
typename tuple_gen_impl<IMiddle, NRemainRight, F>::type>::type;
};
template <index_t I, class F>
struct tuple_gen_impl<I, 1, F>
{
using type = tuple<decltype(F{}(Number<I>{}))>;
};
template <index_t I, class F>
struct tuple_gen_impl<I, 0, F>
{
using type = tuple<>;
};
template <index_t NSize, class F>
struct tuple_gen
{
using type = typename tuple_gen_impl<0, NSize, F>::type;
};
} // namespace ck
#endif
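
A small usage sketch of the tuple helpers (types are illustrative):

#include "tuple.hpp"

using T2 = ck::tuple<int, float>;
static_assert(T2::GetSize() == 2, "GetSize counts the element types");
// for an instance t, t.Get(Number<0>{}) and t[Number<0>{}] return elements by value via std::get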