Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
71c8181c
"include/vscode:/vscode.git/clone" did not exist on "e96eb48724924a5fc4b5aa6f548f262bf75f230d"
Unverified
Commit
71c8181c
authored
Apr 19, 2023
by
Umang Yadav
Committed by
GitHub
Apr 19, 2023
Browse files
Update multi() to work with non-std shapes (#1690)
Solves #1311
parent
f92e7994
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
19 deletions
+33
-19
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+8
-4
src/shape.cpp
src/shape.cpp
+12
-15
test/shape_test.cpp
test/shape_test.cpp
+13
-0
No files found.
src/include/migraphx/shape.hpp
View file @
71c8181c
...
@@ -222,11 +222,15 @@ struct shape
...
@@ -222,11 +222,15 @@ struct shape
/// Map element index to space index
/// Map element index to space index
std
::
size_t
index
(
std
::
size_t
i
)
const
;
std
::
size_t
index
(
std
::
size_t
i
)
const
;
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
i
)
const
;
/// Map element index to multi-dimensional index
void
multi_copy
(
std
::
size_t
i
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
;
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
idx
)
const
;
/// Returns true if the shape is packed (number of elements and buffer size the same) with no
/// Map element index to multi-dimensional index and put them them into location provided by
/// padding
/// pointers
void
multi_copy
(
std
::
size_t
idx
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
;
/// Returns true if the shape is packed (number of elements and buffer size the same) with
/// no padding
bool
packed
()
const
;
bool
packed
()
const
;
/// Returns true is the shape has been transposed. That is the strides are not in descending
/// Returns true is the shape has been transposed. That is the strides are not in descending
...
...
src/shape.cpp
View file @
71c8181c
...
@@ -361,29 +361,26 @@ std::size_t shape::index(std::size_t i) const
...
@@ -361,29 +361,26 @@ std::size_t shape::index(std::size_t i) const
}
}
}
}
std
::
vector
<
std
::
size_t
>
shape
::
multi
(
std
::
size_t
i
)
const
std
::
vector
<
std
::
size_t
>
shape
::
multi
(
std
::
size_t
i
dx
)
const
{
{
assert
(
this
->
standard
());
assert
(
idx
<
elements
());
std
::
vector
<
std
::
size_t
>
indices
(
lens
().
size
());
std
::
vector
<
std
::
size_t
>
indices
(
lens
().
size
());
multi_copy
(
i
,
indices
.
data
(),
indices
.
data
()
+
lens
().
size
());
multi_copy
(
idx
,
indices
.
data
(),
indices
.
data
()
+
lens
().
size
());
return
indices
;
return
indices
;
}
}
void
shape
::
multi_copy
(
std
::
size_t
i
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
void
shape
::
multi_copy
(
std
::
size_t
i
dx
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
{
{
assert
(
this
->
standard
())
;
size_t
tidx
=
idx
;
(
void
)
end
;
(
void
)
end
;
assert
(
idx
<
elements
());
assert
(
lens
().
size
()
<=
(
end
-
start
));
assert
(
lens
().
size
()
<=
(
end
-
start
));
std
::
transform
(
strides
().
begin
(),
for
(
size_t
ii
=
lens
().
size
()
-
1
;
ii
>
0
;
ii
--
)
strides
().
end
(),
{
lens
().
begin
(),
*
(
start
+
ii
)
=
tidx
%
lens
()[
ii
];
start
,
tidx
=
tidx
/
lens
()[
ii
];
[
&
](
std
::
size_t
stride
,
std
::
size_t
len
)
{
}
assert
(
len
>
0
and
stride
>
0
);
*
start
=
tidx
;
return
(
i
/
stride
)
%
len
;
});
}
}
bool
shape
::
packed
()
const
bool
shape
::
packed
()
const
...
...
test/shape_test.cpp
View file @
71c8181c
...
@@ -30,6 +30,7 @@
...
@@ -30,6 +30,7 @@
#include <array>
#include <array>
#include <algorithm>
#include <algorithm>
#include <numeric>
#include <numeric>
#include <migraphx/verify.hpp>
#include "test.hpp"
#include "test.hpp"
TEST_CASE
(
test_shape_default
)
TEST_CASE
(
test_shape_default
)
...
@@ -929,4 +930,16 @@ TEST_CASE(test_with_type)
...
@@ -929,4 +930,16 @@ TEST_CASE(test_with_type)
EXPECT
(
s
.
strides
()
==
new_s
.
strides
());
EXPECT
(
s
.
strides
()
==
new_s
.
strides
());
}
}
TEST_CASE
(
test_multi_index
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
4
,
6
}};
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
0
),
std
::
vector
<
size_t
>
{
0
,
0
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
4
),
std
::
vector
<
size_t
>
{
0
,
0
,
4
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
6
),
std
::
vector
<
size_t
>
{
0
,
1
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
8
),
std
::
vector
<
size_t
>
{
0
,
1
,
2
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
24
),
std
::
vector
<
size_t
>
{
1
,
0
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
30
),
std
::
vector
<
size_t
>
{
1
,
1
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
34
),
std
::
vector
<
size_t
>
{
1
,
1
,
4
}));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment