Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
71c8181c
Unverified
Commit
71c8181c
authored
Apr 19, 2023
by
Umang Yadav
Committed by
GitHub
Apr 19, 2023
Browse files
Update multi() to work with non-std shapes (#1690)
Solves #1311
parent
f92e7994
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
19 deletions
+33
-19
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+8
-4
src/shape.cpp
src/shape.cpp
+12
-15
test/shape_test.cpp
test/shape_test.cpp
+13
-0
No files found.
src/include/migraphx/shape.hpp
View file @
71c8181c
...
...
@@ -222,11 +222,15 @@ struct shape
/// Map element index to space index
std
::
size_t
index
(
std
::
size_t
i
)
const
;
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
i
)
const
;
void
multi_copy
(
std
::
size_t
i
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
;
/// Map element index to multi-dimensional index
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
idx
)
const
;
/// Returns true if the shape is packed (number of elements and buffer size the same) with no
/// padding
/// Map element index to multi-dimensional index and put them them into location provided by
/// pointers
void
multi_copy
(
std
::
size_t
idx
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
;
/// Returns true if the shape is packed (number of elements and buffer size the same) with
/// no padding
bool
packed
()
const
;
/// Returns true is the shape has been transposed. That is the strides are not in descending
...
...
src/shape.cpp
View file @
71c8181c
...
...
@@ -361,29 +361,26 @@ std::size_t shape::index(std::size_t i) const
}
}
std
::
vector
<
std
::
size_t
>
shape
::
multi
(
std
::
size_t
i
)
const
std
::
vector
<
std
::
size_t
>
shape
::
multi
(
std
::
size_t
i
dx
)
const
{
assert
(
this
->
standard
());
assert
(
idx
<
elements
());
std
::
vector
<
std
::
size_t
>
indices
(
lens
().
size
());
multi_copy
(
i
,
indices
.
data
(),
indices
.
data
()
+
lens
().
size
());
multi_copy
(
idx
,
indices
.
data
(),
indices
.
data
()
+
lens
().
size
());
return
indices
;
}
void
shape
::
multi_copy
(
std
::
size_t
i
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
void
shape
::
multi_copy
(
std
::
size_t
i
dx
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
{
assert
(
this
->
standard
())
;
size_t
tidx
=
idx
;
(
void
)
end
;
assert
(
idx
<
elements
());
assert
(
lens
().
size
()
<=
(
end
-
start
));
std
::
transform
(
strides
().
begin
(),
strides
().
end
(),
lens
().
begin
(),
start
,
[
&
](
std
::
size_t
stride
,
std
::
size_t
len
)
{
assert
(
len
>
0
and
stride
>
0
);
return
(
i
/
stride
)
%
len
;
});
for
(
size_t
ii
=
lens
().
size
()
-
1
;
ii
>
0
;
ii
--
)
{
*
(
start
+
ii
)
=
tidx
%
lens
()[
ii
];
tidx
=
tidx
/
lens
()[
ii
];
}
*
start
=
tidx
;
}
bool
shape
::
packed
()
const
...
...
test/shape_test.cpp
View file @
71c8181c
...
...
@@ -30,6 +30,7 @@
#include <array>
#include <algorithm>
#include <numeric>
#include <migraphx/verify.hpp>
#include "test.hpp"
TEST_CASE
(
test_shape_default
)
...
...
@@ -929,4 +930,16 @@ TEST_CASE(test_with_type)
EXPECT
(
s
.
strides
()
==
new_s
.
strides
());
}
TEST_CASE
(
test_multi_index
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
4
,
6
}};
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
0
),
std
::
vector
<
size_t
>
{
0
,
0
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
4
),
std
::
vector
<
size_t
>
{
0
,
0
,
4
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
6
),
std
::
vector
<
size_t
>
{
0
,
1
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
8
),
std
::
vector
<
size_t
>
{
0
,
1
,
2
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
24
),
std
::
vector
<
size_t
>
{
1
,
0
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
30
),
std
::
vector
<
size_t
>
{
1
,
1
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
34
),
std
::
vector
<
size_t
>
{
1
,
1
,
4
}));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment