Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3032998d
"tools/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "41c0487ba2b561a44c3faf8086aa3c824f807568"
Commit
3032998d
authored
Feb 14, 2019
by
Paul
Browse files
Improve array test
parent
704752eb
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
70 additions
and
10 deletions
+70
-10
src/include/migraphx/tensor_view.hpp
src/include/migraphx/tensor_view.hpp
+5
-0
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+21
-0
test/py/array.py
test/py/array.py
+44
-10
No files found.
src/include/migraphx/tensor_view.hpp
View file @
3032998d
...
@@ -124,6 +124,11 @@ struct tensor_view
...
@@ -124,6 +124,11 @@ struct tensor_view
return
m_data
+
this
->
size
();
return
m_data
+
this
->
size
();
}
}
std
::
vector
<
T
>
to_vector
()
const
{
return
std
::
vector
<
T
>
(
this
->
begin
(),
this
->
end
());
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
{
{
if
(
!
x
.
empty
())
if
(
!
x
.
empty
())
...
...
src/py/migraphx_py.cpp
View file @
3032998d
...
@@ -28,6 +28,11 @@ struct throw_half
...
@@ -28,6 +28,11 @@ struct throw_half
{
{
throw
std
::
runtime_error
(
"Half not supported in python yet."
);
throw
std
::
runtime_error
(
"Half not supported in python yet."
);
}
}
void
operator
()(
migraphx
::
tensor_view
<
migraphx
::
half
>
)
const
{
throw
std
::
runtime_error
(
"Half not supported in python yet."
);
}
};
};
template
<
class
F
>
template
<
class
F
>
...
@@ -42,6 +47,8 @@ struct skip_half
...
@@ -42,6 +47,8 @@ struct skip_half
}
}
void
operator
()(
migraphx
::
shape
::
as
<
migraphx
::
half
>
)
const
{}
void
operator
()(
migraphx
::
shape
::
as
<
migraphx
::
half
>
)
const
{}
void
operator
()(
migraphx
::
tensor_view
<
migraphx
::
half
>
)
const
{}
};
};
template
<
class
F
>
template
<
class
F
>
...
@@ -50,6 +57,12 @@ void visit_type(const migraphx::shape& s, F f)
...
@@ -50,6 +57,12 @@ void visit_type(const migraphx::shape& s, F f)
s
.
visit_type
(
throw_half
<
F
>
{
f
});
s
.
visit_type
(
throw_half
<
F
>
{
f
});
}
}
template
<
class
T
,
class
F
>
void
visit
(
const
migraphx
::
raw_data
<
T
>&
x
,
F
f
)
{
x
.
visit
(
throw_half
<
F
>
{
f
});
}
template
<
class
F
>
template
<
class
F
>
void
visit_types
(
F
f
)
void
visit_types
(
F
f
)
{
{
...
@@ -123,6 +136,14 @@ PYBIND11_MODULE(migraphx, m)
...
@@ -123,6 +136,14 @@ PYBIND11_MODULE(migraphx, m)
py
::
buffer_info
info
=
b
.
request
();
py
::
buffer_info
info
=
b
.
request
();
new
(
&
x
)
migraphx
::
argument
(
to_shape
(
info
),
info
.
ptr
);
new
(
&
x
)
migraphx
::
argument
(
to_shape
(
info
),
info
.
ptr
);
})
})
.
def
(
"get_shape"
,
&
migraphx
::
argument
::
get_shape
)
.
def
(
"tolist"
,
[](
migraphx
::
argument
&
x
)
{
py
::
list
l
{
x
.
get_shape
().
elements
()};
visit
(
x
,
[
&
](
auto
data
)
{
l
=
py
::
cast
(
data
.
to_vector
());
});
return
l
;
})
.
def
(
"__eq__"
,
std
::
equal_to
<
migraphx
::
argument
>
{})
.
def
(
"__eq__"
,
std
::
equal_to
<
migraphx
::
argument
>
{})
.
def
(
"__ne__"
,
std
::
not_equal_to
<
migraphx
::
argument
>
{})
.
def
(
"__ne__"
,
std
::
not_equal_to
<
migraphx
::
argument
>
{})
.
def
(
"__repr__"
,
[](
const
migraphx
::
argument
&
x
)
{
return
migraphx
::
to_string
(
x
);
});
.
def
(
"__repr__"
,
[](
const
migraphx
::
argument
&
x
)
{
return
migraphx
::
to_string
(
x
);
});
...
...
test/py/array.py
View file @
3032998d
import
migraphx
import
migraphx
,
struct
def
assert_eq
(
x
,
y
):
if
x
==
y
:
pass
else
:
raise
Exception
(
str
(
x
)
+
" != "
+
str
(
y
))
def
get_lens
(
m
):
return
list
(
m
.
shape
)
def
get_strides
(
m
):
return
[
s
/
m
.
itemsize
for
s
in
m
.
strides
]
def
read_float
(
b
,
index
):
return
struct
.
unpack_from
(
'f'
,
b
,
index
*
4
)[
0
]
def
check_list
(
a
,
b
):
l
=
a
.
tolist
()
for
i
in
range
(
len
(
l
)):
assert_eq
(
l
[
i
],
read_float
(
b
,
i
))
def
run
(
p
):
params
=
{}
for
key
,
value
in
p
.
get_parameter_shapes
().
items
():
params
[
key
]
=
migraphx
.
to_gpu
(
migraphx
.
generate_argument
(
value
))
return
migraphx
.
from_gpu
(
p
.
run
(
params
))
p
=
migraphx
.
parse_onnx
(
"conv_relu_maxpool.onnx"
)
p
=
migraphx
.
parse_onnx
(
"conv_relu_maxpool.onnx"
)
p
.
compile
(
migraphx
.
get_target
(
"gpu"
))
p
.
compile
(
migraphx
.
get_target
(
"gpu"
))
params
=
{}
for
key
,
value
in
p
.
get_parameter_shapes
().
items
():
params
[
key
]
=
migraphx
.
to_gpu
(
migraphx
.
generate_argument
(
value
))
r1
=
migraphx
.
from_gpu
(
p
.
run
(
params
))
r1
=
run
(
p
)
r2
=
migraphx
.
from_gpu
(
p
.
run
(
params
))
r2
=
run
(
p
)
assert_eq
(
r1
,
r2
)
assert_eq
(
r1
.
tolist
(),
r2
.
tolist
())
assert_eq
(
r1
.
tolist
()[
0
],
read_float
(
r1
,
0
))
m1
=
memoryview
(
r1
)
m2
=
memoryview
(
r2
)
assert_eq
(
r1
.
get_shape
().
elements
(),
reduce
(
lambda
x
,
y
:
x
*
y
,
get_lens
(
m1
),
1
))
assert_eq
(
r1
.
get_shape
().
lens
(),
get_lens
(
m1
))
assert_eq
(
r1
.
get_shape
().
strides
(),
get_strides
(
m1
))
assert
r1
==
r2
check_list
(
r1
,
m1
.
tobytes
())
q1
=
memoryview
(
r1
)
check_list
(
r2
,
m2
.
tobytes
())
q2
=
memoryview
(
r2
)
assert
q1
.
tobytes
()
==
q2
.
tobytes
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment