Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3e46f548
Commit
3e46f548
authored
Nov 02, 2022
by
charlie
Browse files
Removed the automatic conversion to a standard shape
parent
0af471ff
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
28 additions
and
45 deletions
+28
-45
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+1
-17
src/include/migraphx/shape_for_each.hpp
src/include/migraphx/shape_for_each.hpp
+0
-25
test/literal_test.cpp
test/literal_test.cpp
+1
-1
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+26
-2
No files found.
src/include/migraphx/literal.hpp
View file @
3e46f548
...
...
@@ -113,24 +113,8 @@ struct literal : raw_data<literal>
void
fill
(
Iterator
start
,
Iterator
end
)
{
assert
(
std
::
distance
(
start
,
end
)
==
m_shape
.
elements
());
if
(
m_shape
.
standard
())
{
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
get
()));
});
}
else
{
// make the literal into a standard shape (contiguous)
auto
it
=
start
;
m_shape
.
visit_type
([
&
](
auto
as
)
{
auto
output
=
make_view
(
m_shape
,
as
.
from
(
buffer
.
get
()));
shape_for_each_nstd
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
*
it
;
// NOLINT(bugprone-signed-char-misuse)
it
++
;
});
});
m_shape
=
{
m_shape
.
type
(),
m_shape
.
lens
()};
}
}
};
template
<
class
F
>
...
...
src/include/migraphx/shape_for_each.hpp
View file @
3e46f548
...
...
@@ -55,31 +55,6 @@ void shape_for_each(const migraphx::shape& s, F f)
call
(
indices
);
}
}
/**
* Iterates the given function over the given shape indices.
* Will iterate using non-standard strides if given a non-standard shape.
*/
template
<
class
F
>
void
shape_for_each_nstd
(
const
migraphx
::
shape
&
s
,
F
f
)
{
// Ensure calls to f use const ref to vector
auto
call
=
[
&
f
](
const
std
::
vector
<
std
::
size_t
>&
i
)
{
f
(
i
);
};
std
::
vector
<
std
::
size_t
>
indices
(
s
.
lens
().
size
());
for
(
std
::
size_t
i
=
0
;
i
<
s
.
elements
();
i
++
)
{
std
::
transform
(
s
.
strides
().
begin
(),
s
.
strides
().
end
(),
s
.
lens
().
begin
(),
indices
.
begin
(),
[
&
](
std
::
size_t
stride
,
std
::
size_t
len
)
{
assert
(
len
>
0
and
stride
>
0
);
return
(
i
/
stride
)
%
len
;
});
call
(
indices
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
test/literal_test.cpp
View file @
3e46f548
...
...
@@ -60,7 +60,7 @@ TEST_CASE(literal_nstd_shape)
auto
l0
=
migraphx
::
literal
{
nstd_shape
,
nstd_data
};
auto
l1
=
migraphx
::
literal
{
std_shape
,
std_data
};
EXPECT
(
l0
=
=
l1
);
EXPECT
(
l0
!
=
l1
);
}
TEST_CASE
(
literal_os1
)
...
...
test/ref_ops_test.cpp
View file @
3e46f548
...
...
@@ -835,12 +835,36 @@ TEST_CASE(concat_test)
}
}
TEST_CASE
(
contiguous_
param_
test
)
TEST_CASE
(
contiguous_test
)
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
2
,
2
},
{
12
,
1
,
6
,
3
}};
std
::
vector
<
float
>
data
(
12
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
l
=
mm
->
add_literal
(
migraphx
::
literal
{
a_shape
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l
);
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
result
.
visit
([
&
](
auto
output
)
{
std
::
vector
<
size_t
>
new_strides
=
{
12
,
4
,
2
,
1
};
EXPECT
(
bool
{
output
.
get_shape
().
strides
()
==
new_strides
});
});
std
::
vector
<
float
>
results_vector
(
12
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0
,
3
,
6
,
9
,
1
,
4
,
7
,
10
,
2
,
5
,
8
,
11
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
TEST_CASE
(
contiguous_param_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
2
,
2
},
{
12
,
1
,
6
,
3
}};
auto
a
=
mm
->
add_parameter
(
"X"
,
a_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
a
);
p
.
compile
(
migraphx
::
ref
::
target
{});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment