Unverified Commit 68446f7a authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Test and doc update for shape.from_permutation() (#1742)

Changed the doc for find_permutation(shape) to be more clear that it is finding the permutation that would make the shape standard
parent 5df11e0f
......@@ -56,12 +56,12 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
}
/*!
* Returns the permutation needed to apply to the shape to undo the current permutation
* Returns the inverse permutation that could be applied to undo the inputted permutation
*/
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation);
/*!
* Finds the permutation most likely from a transpose operator that has been applied to the shape.
* Finds the permutation that would make the shape not transposed (refering to shape.transposed())
*/
std::vector<int64_t> find_permutation(const shape& s);
std::vector<int64_t> find_permutation(const std::vector<shape>& shapes);
......
......@@ -156,8 +156,28 @@ struct shape
shape(const std::vector<shape>& subs);
/**
* Creates an output shape with dimensions equal to the input lengths and strides determined
* by the permutation argument such that find_permutation() of the output shape returns the
* inputted permuation.
*
* 2D example:
* parameters:
* l = [2, 3], perm = [1, 0]
* therefore:
* "original" shape = {lens = [3, 2], strides = [2, 1]}
* output_shape = {lens = [2, 3], strides = [1, 2]
*
* 3D example:
* parameters:
* l = [2, 3, 4], perm = [1, 2, 0]
* therefore:
* "original" shape = {lens = [3, 4, 2], strides = [8, 2, 1]}
* output_shape = {lens = [2, 3, 4], strides = [1, 8, 2]}
*/
static shape
from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm);
type_t type() const;
const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const;
......
......@@ -956,4 +956,67 @@ TEST_CASE(test_multi_index)
EXPECT(migraphx::verify_range(s.multi(34), std::vector<size_t>{1, 1, 4}));
}
TEST_CASE(find_permutation_2d_standard)
{
migraphx::shape s = {migraphx::shape::float_type, {2, 3}};
std::vector<int64_t> permutation = {0, 1};
EXPECT(migraphx::find_permutation(s) == permutation);
}
TEST_CASE(find_permutation_2d_transpose)
{
migraphx::shape s = {migraphx::shape::float_type, {2, 3}, {1, 2}};
std::vector<int64_t> permutation = {1, 0};
EXPECT(migraphx::find_permutation(s) == permutation);
}
TEST_CASE(find_permutation_3d)
{
migraphx::shape s = {migraphx::shape::float_type, {2, 3, 4}, {1, 8, 2}};
std::vector<int64_t> permutation = {1, 2, 0};
EXPECT(migraphx::find_permutation(s) == permutation);
}
TEST_CASE(find_permutation_4d)
{
// ori_lens = 2, 3, 4, 5
// ori_strides = 60, 20, 5, 1
// perm = 3, 2, 0, 1
// inv_perm = 2, 3, 1, 0
// out_strides = 5, 1, 20, 60
migraphx::shape s = {migraphx::shape::float_type, {5, 4, 2, 3}, {5, 1, 20, 60}};
std::vector<int64_t> permutation = {3, 2, 0, 1};
EXPECT(migraphx::find_permutation(s) == permutation);
}
TEST_CASE(from_2d_permutation)
{
std::vector<std::size_t> out_lens = {2, 3};
std::vector<int64_t> permutation = {1, 0};
migraphx::shape out_shape =
migraphx::shape::from_permutation(migraphx::shape::float_type, out_lens, permutation);
EXPECT(out_shape.lens() == out_lens);
EXPECT(migraphx::find_permutation(out_shape) == permutation);
}
TEST_CASE(from_3d_permutation)
{
std::vector<std::size_t> out_lens = {2, 3, 4};
std::vector<int64_t> permutation = {1, 2, 0};
migraphx::shape out_shape =
migraphx::shape::from_permutation(migraphx::shape::float_type, out_lens, permutation);
EXPECT(out_shape.lens() == out_lens);
EXPECT(migraphx::find_permutation(out_shape) == permutation);
}
TEST_CASE(from_4d_permutation)
{
std::vector<std::size_t> out_lens = {5, 4, 2, 3};
std::vector<int64_t> permutation = {3, 2, 0, 1};
migraphx::shape out_shape =
migraphx::shape::from_permutation(migraphx::shape::float_type, out_lens, permutation);
EXPECT(out_shape.lens() == out_lens);
EXPECT(migraphx::find_permutation(out_shape) == permutation);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment