Commit 5b63400a authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Passing 'axes' to 'DevicePermute'

parent 50f5ce49
...@@ -21,22 +21,23 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem) ...@@ -21,22 +21,23 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
a_device_buf.ToDevice(a.mData.data()); a_device_buf.ToDevice(a.mData.data());
std::array<ck::index_t, 4> a_lengths, b_lengths; std::array<ck::index_t, 4> a_lengths;
std::array<ck::index_t, 4> axes;
std::array<ck::index_t, 4> a_strides, b_strides; std::array<ck::index_t, 4> a_strides, b_strides;
const void* input = a_device_buf.GetDeviceBuffer(); const void* input = a_device_buf.GetDeviceBuffer();
void* output = b_device_buf.GetDeviceBuffer(); void* output = b_device_buf.GetDeviceBuffer();
std::copy(begin(shape), end(shape), begin(a_lengths)); std::copy(begin(shape), end(shape), begin(a_lengths));
std::copy(begin(transposed_shape), end(transposed_shape), begin(b_lengths)); std::copy(begin(problem.axes), end(problem.axes), begin(axes));
std::copy(begin(a.mDesc.GetStrides()), end(a.mDesc.GetStrides()), begin(a_strides)); std::copy(begin(a.mDesc.GetStrides()), end(a.mDesc.GetStrides()), begin(a_strides));
std::copy(begin(b.mDesc.GetStrides()), end(b.mDesc.GetStrides()), begin(b_strides)); std::copy(begin(b.mDesc.GetStrides()), end(b.mDesc.GetStrides()), begin(b_strides));
static_assert(std::is_default_constructible_v<DevicePermuteInstance>); static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
auto permute = DevicePermuteInstance{}; auto permute = DevicePermuteInstance{};
auto argument = permute.MakeArgument( auto argument =
a_lengths, a_strides, b_lengths, b_strides, input, output, PassThrough{}); permute.MakeArgument(a_lengths, axes, a_strides, b_strides, input, output, PassThrough{});
if(!permute.IsSupportedArgument(argument)) if(!permute.IsSupportedArgument(argument))
{ {
......
...@@ -162,8 +162,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -162,8 +162,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::array<index_t, NumDim> inLengths, Argument(const std::array<index_t, NumDim> inLengths,
const std::array<index_t, NumDim> axes,
const std::array<index_t, NumDim> inStrides, const std::array<index_t, NumDim> inStrides,
const std::array<index_t, NumDim> outLengths,
const std::array<index_t, NumDim> outStrides, const std::array<index_t, NumDim> outStrides,
const void* in_dev_buffer, const void* in_dev_buffer,
void* out_dev_buffer, void* out_dev_buffer,
...@@ -171,8 +171,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -171,8 +171,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
: blockSize_(256), : blockSize_(256),
gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future
inLengths_(inLengths), inLengths_(inLengths),
axes_(axes),
inStridesArray_({inStrides}), inStridesArray_({inStrides}),
outLengths_(outLengths),
outStridesArray_({outStrides}), outStridesArray_({outStrides}),
elementwise_op_(elementwise_op) elementwise_op_(elementwise_op)
{ {
...@@ -196,7 +196,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -196,7 +196,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
out_grid_1d_desc_tuple_ = generate_tuple( out_grid_1d_desc_tuple_ = generate_tuple(
[&](auto) { [&](auto) {
return MakeDescriptor_M(outLengths, outStrides, gridSize_, blockSize_); return MakeDescriptor_M(inLengths, outStrides, gridSize_, blockSize_);
}, },
Number<NumOutput>{}); Number<NumOutput>{});
} }
...@@ -210,8 +210,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -210,8 +210,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
OutGrid1dDescTuple out_grid_1d_desc_tuple_; OutGrid1dDescTuple out_grid_1d_desc_tuple_;
std::array<index_t, NumDim> inLengths_; std::array<index_t, NumDim> inLengths_;
std::array<index_t, NumDim> axes_;
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_; std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
std::array<index_t, NumDim> outLengths_;
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_; std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
ElementwiseOperation elementwise_op_; ElementwiseOperation elementwise_op_;
...@@ -244,7 +244,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -244,7 +244,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(arg.inLengths_.back() % MPerThread == 0 && arg.outLengths_.back() % MPerThread == 0)) if(arg.inLengths_.back() % MPerThread != 0)
{ {
return false; return false;
} }
...@@ -270,7 +270,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -270,7 +270,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
static_for<0, NumOutput, 1>{}([&](auto I) { static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid( if(!IsScalarPerVectorValid(
arg.outLengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I))) arg.inLengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
valid = false; valid = false;
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment