"docs/source/en/api/models.mdx" did not exist on "5e6417e9887be8f02ab5b4f5c548dff7f3a4c8f6"
Commit 50f5ce49 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Distinguish input & output shape in 'DevicePermute'

parent 2377c2e8
......@@ -21,21 +21,22 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
a_device_buf.ToDevice(a.mData.data());
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_lengths, b_lengths;
std::array<ck::index_t, 4> a_strides, b_strides;
const void* input = a_device_buf.GetDeviceBuffer();
void* output = b_device_buf.GetDeviceBuffer();
std::copy(begin(shape), end(shape), begin(ab_lengths));
std::copy(begin(shape), end(shape), begin(a_lengths));
std::copy(begin(transposed_shape), end(transposed_shape), begin(b_lengths));
std::copy(begin(a.mDesc.GetStrides()), end(a.mDesc.GetStrides()), begin(a_strides));
std::copy(begin(b.mDesc.GetStrides()), end(b.mDesc.GetStrides()), begin(b_strides));
static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
auto permute = DevicePermuteInstance{};
auto argument =
permute.MakeArgument(ab_lengths, a_strides, b_strides, input, output, PassThrough{});
auto permute = DevicePermuteInstance{};
auto argument = permute.MakeArgument(
a_lengths, a_strides, b_lengths, b_strides, input, output, PassThrough{});
if(!permute.IsSupportedArgument(argument))
{
......
......@@ -161,16 +161,18 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumDim> lengths,
Argument(const std::array<index_t, NumDim> inLengths,
const std::array<index_t, NumDim> inStrides,
const std::array<index_t, NumDim> outLengths,
const std::array<index_t, NumDim> outStrides,
const void* in_dev_buffer,
void* out_dev_buffer,
ElementwiseOperation elementwise_op)
: blockSize_(256),
gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future
lengths_(lengths),
inLengths_(inLengths),
inStridesArray_({inStrides}),
outLengths_(outLengths),
outStridesArray_({outStrides}),
elementwise_op_(elementwise_op)
{
......@@ -189,11 +191,13 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
Number<NumOutput>{});
in_grid_1d_desc_tuple_ = generate_tuple(
[&](auto) { return MakeDescriptor_M(lengths, inStrides, gridSize_, blockSize_); },
[&](auto) { return MakeDescriptor_M(inLengths, inStrides, gridSize_, blockSize_); },
Number<NumInput>{});
out_grid_1d_desc_tuple_ = generate_tuple(
[&](auto) { return MakeDescriptor_M(lengths, outStrides, gridSize_, blockSize_); },
[&](auto) {
return MakeDescriptor_M(outLengths, outStrides, gridSize_, blockSize_);
},
Number<NumOutput>{});
}
......@@ -205,8 +209,9 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
InGrid1dDescTuple in_grid_1d_desc_tuple_;
OutGrid1dDescTuple out_grid_1d_desc_tuple_;
std::array<index_t, NumDim> lengths_;
std::array<index_t, NumDim> inLengths_;
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
std::array<index_t, NumDim> outLengths_;
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
ElementwiseOperation elementwise_op_;
......@@ -239,8 +244,10 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
static bool IsSupportedArgument(const Argument& arg)
{
if(arg.lengths_.back() % MPerThread != 0)
if(!(arg.inLengths_.back() % MPerThread == 0 && arg.outLengths_.back() % MPerThread == 0))
{
return false;
}
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
......@@ -257,13 +264,13 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(
arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
arg.inLengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
valid = false;
});
static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(
arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
arg.outLengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
valid = false;
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment