Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5f3f6e73
Commit
5f3f6e73
authored
Nov 14, 2022
by
charlie
Browse files
Detach from non-std literal PR
parent
2cf7ae45
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
19 additions
and
27 deletions
+19
-27
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+15
-1
src/include/migraphx/shape_for_each.hpp
src/include/migraphx/shape_for_each.hpp
+1
-4
test/literal_test.cpp
test/literal_test.cpp
+0
-14
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+3
-8
No files found.
src/include/migraphx/literal.hpp
View file @
5f3f6e73
...
@@ -111,7 +111,21 @@ struct literal : raw_data<literal>
...
@@ -111,7 +111,21 @@ struct literal : raw_data<literal>
void
fill
(
Iterator
start
,
Iterator
end
)
void
fill
(
Iterator
start
,
Iterator
end
)
{
{
assert
(
std
::
distance
(
start
,
end
)
==
m_shape
.
elements
());
assert
(
std
::
distance
(
start
,
end
)
==
m_shape
.
elements
());
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
get
()));
});
if
(
m_shape
.
standard
())
{
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
get
()));
});
}
else
{
auto
it
=
start
;
m_shape
.
visit_type
([
&
](
auto
as
)
{
auto
output
=
make_view
(
m_shape
,
as
.
from
(
buffer
.
get
()));
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
*
it
;
// NOLINT(bugprone-signed-char-misuse)
it
++
;
});
});
}
}
}
};
};
...
...
src/include/migraphx/shape_for_each.hpp
View file @
5f3f6e73
...
@@ -31,10 +31,6 @@
...
@@ -31,10 +31,6 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
* Iterates the given function over the standard shape indices.
* Will iterate using standard strides if given a non-standard shape.
*/
template
<
class
F
>
template
<
class
F
>
void
shape_for_each
(
const
migraphx
::
shape
&
s
,
F
f
)
void
shape_for_each
(
const
migraphx
::
shape
&
s
,
F
f
)
{
{
...
@@ -55,6 +51,7 @@ void shape_for_each(const migraphx::shape& s, F f)
...
@@ -55,6 +51,7 @@ void shape_for_each(const migraphx::shape& s, F f)
call
(
indices
);
call
(
indices
);
}
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
test/literal_test.cpp
View file @
5f3f6e73
...
@@ -49,20 +49,6 @@ TEST_CASE(literal_test)
...
@@ -49,20 +49,6 @@ TEST_CASE(literal_test)
EXPECT
(
l4
.
empty
());
EXPECT
(
l4
.
empty
());
}
}
TEST_CASE
(
literal_nstd_shape
)
{
migraphx
::
shape
nstd_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
2
,
2
},
{
12
,
1
,
6
,
3
}};
std
::
vector
<
float
>
nstd_data
(
12
);
std
::
iota
(
nstd_data
.
begin
(),
nstd_data
.
end
(),
0
);
migraphx
::
shape
std_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
2
,
2
}};
std
::
vector
<
float
>
std_data
=
{
0
,
3
,
6
,
9
,
1
,
4
,
7
,
10
,
2
,
5
,
8
,
11
};
auto
l0
=
migraphx
::
literal
{
nstd_shape
,
nstd_data
};
auto
l1
=
migraphx
::
literal
{
std_shape
,
std_data
};
EXPECT
(
l0
!=
l1
);
}
TEST_CASE
(
literal_os1
)
TEST_CASE
(
literal_os1
)
{
{
migraphx
::
literal
l
{
1
};
migraphx
::
literal
l
{
1
};
...
...
test/ref_ops_test.cpp
View file @
5f3f6e73
...
@@ -848,16 +848,11 @@ TEST_CASE(contiguous_test)
...
@@ -848,16 +848,11 @@ TEST_CASE(contiguous_test)
p.compile(migraphx::ref::target{});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto result = p.eval({}).back();
result.visit([&](auto output) {
std::vector<size_t> new_strides = {12, 4, 2, 1};
EXPECT(bool{output.get_shape().strides() == new_strides});
});
std::vector<float> results_vector(12);
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<
float> gold = {0, 3, 6, 9, 1, 4, 7
, 1
0
,
2
,
5, 8, 11
};
std::vector<
size_t> new_strides = {12
, 1,
6
,
3
};
EXPECT(migraphx::verify_range(results_vector,
gold
));
EXPECT(migraphx::verify_range(results_vector,
data
));
}
}
TEST_CASE(contiguous_param_test)
TEST_CASE(contiguous_param_test)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment