Commit 32a2d78b authored by Po-Yen, Chen
Browse files

Use indirect base type to generate methods

parent e53b50e8
......@@ -207,6 +207,12 @@ inline constexpr bool is_device_op_v = is_device_op<T>::value;
} // namespace detail
/// Returns whatever `range.front()` yields, preserving the value category of `range`.
///
/// The trailing return type makes the overload drop out of resolution (SFINAE) when
/// `Range` has no usable `front()` member.
template <typename Range>
auto front(Range&& range) -> decltype(std::forward<Range>(range).front())
{
    // static_cast<Range&&> is exactly what std::forward<Range> performs.
    return static_cast<Range&&>(range).front();
}
template <typename Axes>
inline std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
is_valid_axes(const Axes& axes)
......
......@@ -21,22 +21,22 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
a_device_buf.ToDevice(a.mData.data());
std::array<ck::index_t, 4> ab_lengths;
std::array<std::array<ck::index_t, 4>, 1> a_strides, b_strides;
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides, b_strides;
std::copy(begin(shape), end(shape), begin(ab_lengths));
std::copy(begin(a.mDesc.GetStrides()), end(a.mDesc.GetStrides()), begin(a_strides));
std::copy(begin(b.mDesc.GetStrides()), end(b.mDesc.GetStrides()), begin(b_strides));
std::copy(begin(a.mDesc.GetStrides()), end(a.mDesc.GetStrides()), begin(front(a_strides)));
std::copy(begin(b.mDesc.GetStrides()), end(b.mDesc.GetStrides()), begin(front(b_strides)));
static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
static_assert(detail::is_device_op_v<DevicePermuteInstance>);
// static_assert(detail::is_device_op_v<DevicePermuteInstance>);
auto permute = DevicePermuteInstance{};
auto argument =
permute.MakeArgument(ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
permute.MakeArgument(ab_lengths, a_strides, b_strides, input, output, PassThrough{});
if(!permute.IsSupportedArgument(argument))
{
......
......@@ -18,10 +18,38 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace detail {
/// CRTP base that supplies the generic operator entry points for a permute device op.
///
/// Every member forwards to names looked up in `Derived`: the nested types
/// `Derived::Argument` and `Derived::Invoker`, and `Derived::IsSupportedArgument`.
template <typename Derived>
struct DevicePermuteBase : BaseOperator
{
    /// Type-erased support check.
    ///
    /// @param arg  Pointer to a `BaseArgument`; may be null.
    /// @return `false` when `arg` is null or is not a `Derived::Argument`; otherwise the
    ///         result of `Derived::IsSupportedArgument` on the down-cast object.
    bool IsSupportedArgument(const BaseArgument* arg) override final
    {
        using ArgumentType = typename Derived::Argument;

        if(const auto* typed_arg = dynamic_cast<const ArgumentType*>(arg))
        {
            return Derived::IsSupportedArgument(*typed_arg);
        }

        return false;
    }

    /// Brace-initializes a `Derived::Argument` from the perfectly forwarded `args`.
    template <typename... Args>
    static auto MakeArgument(Args&&... args)
    {
        using ArgumentType = typename Derived::Argument;

        return ArgumentType{std::forward<Args>(args)...};
    }

    /// Heap-allocates a `Derived::Argument` from the perfectly forwarded `args` and
    /// returns it as `std::unique_ptr<typename Derived::Argument>`.
    template <typename... Args>
    static auto MakeArgumentPointer(Args&&... args)
    {
        using ArgumentType = typename Derived::Argument;

        return std::make_unique<ArgumentType>(std::forward<Args>(args)...);
    }

    /// Returns a `Derived::Invoker` created with empty braces.
    static auto MakeInvoker()
    {
        using InvokerType = typename Derived::Invoker;

        return InvokerType{};
    }

    /// Returns a heap-allocated `Derived::Invoker` owned by a `std::unique_ptr`.
    static auto MakeInvokerPointer()
    {
        using InvokerType = typename Derived::Invoker;

        return std::make_unique<InvokerType>();
    }
};
} // namespace detail
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
......@@ -30,7 +58,7 @@ template <typename InDataTypeTuple,
index_t MPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DevicePermute : DevicePermuteBase<DevicePermute<InDataTypeTuple,
struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
NumDim,
......@@ -264,46 +292,6 @@ struct DevicePermute : DevicePermuteBase<DevicePermute<InDataTypeTuple,
return valid;
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
/// Builds an `Argument` by value.
///
/// @param lengths          `NumDim` length entries.
/// @param inStridesArray   One `NumDim`-entry stride array per input (`NumInput` of them).
/// @param outStridesArray  One `NumDim`-entry stride array per output (`NumOutput` of them).
/// @param in_dev_buffers   One `const void*` per input.
/// @param out_dev_buffers  One `void*` per output.
/// @param elementwise_op   Element-wise operation object handed to the argument.
/// @return An `Argument` brace-initialized from the parameters in the order received.
static auto
MakeArgument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
{
    return Argument{
        lengths, inStridesArray, outStridesArray, in_dev_buffers, out_dev_buffers, elementwise_op};
}
/// Heap-allocating counterpart of `MakeArgument`.
///
/// Takes the same six parameters, constructs an `Argument` from them with
/// `std::make_unique`, and returns it through a `std::unique_ptr<BaseArgument>`.
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
{
    // Create the concrete argument first, then hand it out through the base-class pointer type.
    std::unique_ptr<BaseArgument> argument_ptr = std::make_unique<Argument>(
        lengths, inStridesArray, outStridesArray, in_dev_buffers, out_dev_buffers, elementwise_op);

    return argument_ptr;
}
/// Returns an `Invoker` created with empty braces.
static auto MakeInvoker()
{
    return Invoker{};
}
/// Returns a heap-allocated `Invoker` through a `std::unique_ptr<BaseInvoker>`.
std::unique_ptr<BaseInvoker> MakeInvokerPointer()
{
    return std::make_unique<Invoker>();
}
}; // namespace device
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment