Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0af471ff
Commit
0af471ff
authored
Nov 02, 2022
by
charlie
Browse files
Fixing non-standard shape literal construction
parent
1820198e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
10 deletions
+63
-10
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+5
-1
src/include/migraphx/shape_for_each.hpp
src/include/migraphx/shape_for_each.hpp
+28
-0
test/literal_test.cpp
test/literal_test.cpp
+14
-0
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+16
-9
No files found.
src/include/migraphx/literal.hpp
View file @
0af471ff
...
...
@@ -40,6 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
/**
* @brief Represents a raw literal
* @details This stores the literal has a raw buffer that is owned by this class
* If the given shape is non-standard, the literal will be converted to a standard shape at
* construction.
*/
struct
literal
:
raw_data
<
literal
>
{
...
...
@@ -117,14 +119,16 @@ struct literal : raw_data<literal>
}
else
{
// make the literal into a standard shape (contiguous)
auto
it
=
start
;
m_shape
.
visit_type
([
&
](
auto
as
)
{
auto
output
=
make_view
(
m_shape
,
as
.
from
(
buffer
.
get
()));
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
shape_for_each
_nstd
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
*
it
;
// NOLINT(bugprone-signed-char-misuse)
it
++
;
});
});
m_shape
=
{
m_shape
.
type
(),
m_shape
.
lens
()};
}
}
};
...
...
src/include/migraphx/shape_for_each.hpp
View file @
0af471ff
...
...
@@ -31,6 +31,10 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
* Iterates the given function over the standard shape indices.
* Will iterate using standard strides if given a non-standard shape.
*/
template
<
class
F
>
void
shape_for_each
(
const
migraphx
::
shape
&
s
,
F
f
)
{
...
...
@@ -52,6 +56,30 @@ void shape_for_each(const migraphx::shape& s, F f)
}
}
/**
* Iterates the given function over the given shape indices.
* Will iterate using non-standard strides if given a non-standard shape.
*/
template
<
class
F
>
void
shape_for_each_nstd
(
const
migraphx
::
shape
&
s
,
F
f
)
{
// Ensure calls to f use const ref to vector
auto
call
=
[
&
f
](
const
std
::
vector
<
std
::
size_t
>&
i
)
{
f
(
i
);
};
std
::
vector
<
std
::
size_t
>
indices
(
s
.
lens
().
size
());
for
(
std
::
size_t
i
=
0
;
i
<
s
.
elements
();
i
++
)
{
std
::
transform
(
s
.
strides
().
begin
(),
s
.
strides
().
end
(),
s
.
lens
().
begin
(),
indices
.
begin
(),
[
&
](
std
::
size_t
stride
,
std
::
size_t
len
)
{
assert
(
len
>
0
and
stride
>
0
);
return
(
i
/
stride
)
%
len
;
});
call
(
indices
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
test/literal_test.cpp
View file @
0af471ff
...
...
@@ -49,6 +49,20 @@ TEST_CASE(literal_test)
EXPECT
(
l4
.
empty
());
}
TEST_CASE
(
literal_nstd_shape
)
{
migraphx
::
shape
nstd_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
2
,
2
},
{
12
,
1
,
6
,
3
}};
std
::
vector
<
float
>
nstd_data
(
12
);
std
::
iota
(
nstd_data
.
begin
(),
nstd_data
.
end
(),
0
);
migraphx
::
shape
std_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
2
,
2
}};
std
::
vector
<
float
>
std_data
=
{
0
,
3
,
6
,
9
,
1
,
4
,
7
,
10
,
2
,
5
,
8
,
11
};
auto
l0
=
migraphx
::
literal
{
nstd_shape
,
nstd_data
};
auto
l1
=
migraphx
::
literal
{
std_shape
,
std_data
};
EXPECT
(
l0
==
l1
);
}
TEST_CASE
(
literal_os1
)
{
migraphx
::
literal
l
{
1
};
...
...
test/ref_ops_test.cpp
View file @
0af471ff
...
...
@@ -835,24 +835,31 @@ TEST_CASE(concat_test)
}
}
TEST_CASE
(
contiguous_test
)
TEST_CASE
(
contiguous_
param_
test
)
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
2
,
2
},
{
12
,
1
,
6
,
3
}};
std
::
vector
<
float
>
data
(
12
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
l
=
mm
->
add_
literal
(
migraphx
::
literal
{
a_shape
,
data
}
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l
);
auto
a
=
mm
->
add_
parameter
(
"X"
,
a_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
a
);
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
data
(
12
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
migraphx
::
parameter_map
params
;
params
[
"X"
]
=
migraphx
::
argument
(
a_shape
,
data
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
result
.
visit
([
&
](
auto
output
)
{
std
::
vector
<
size_t
>
new_strides
=
{
12
,
4
,
2
,
1
};
EXPECT
(
bool
{
output
.
get_shape
().
strides
()
==
new_strides
});
});
std
::
vector
<
float
>
results_vector
(
12
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
size_t
>
new_lens
=
{
1
,
3
,
2
,
2
};
std
::
vector
<
size_t
>
new_strides
=
{
12
,
1
,
6
,
3
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
data
));
std
::
vector
<
float
>
gold
=
{
0
,
3
,
6
,
9
,
1
,
4
,
7
,
10
,
2
,
5
,
8
,
11
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
TEST_CASE
(
conv_dynamic_batch_test
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment