Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
11c7cee5
Commit
11c7cee5
authored
Feb 09, 2019
by
Paul
Browse files
Support buffer info constructor
parent
8a3d1d09
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
3 deletions
+62
-3
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+24
-0
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+38
-3
No files found.
src/include/migraphx/shape.hpp
View file @
11c7cee5
...
...
@@ -61,6 +61,16 @@ struct shape
shape
(
type_t
t
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
template
<
class
Range
>
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
{}
template
<
class
Range1
,
class
Range2
>
shape
(
type_t
t
,
const
Range1
&
l
,
const
Range2
&
s
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()),
std
::
vector
<
std
::
size_t
>
(
s
.
begin
(),
s
.
end
()))
{}
type_t
type
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
...
...
@@ -141,6 +151,11 @@ struct shape
{
return
reinterpret_cast
<
const
T
*>
(
buffer
)
+
n
;
}
type_t
type_enum
()
const
{
return
get_type
<
T
>
{};
}
};
template
<
class
Visitor
>
...
...
@@ -156,6 +171,15 @@ struct shape
MIGRAPHX_THROW
(
"Unknown type"
);
}
template
<
class
Visitor
>
static
void
visit_types
(
Visitor
v
)
{
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL(x, t) \
v(as<t>());
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
)
#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
}
private:
std
::
shared_ptr
<
const
shape_impl
>
impl
;
...
...
src/py/migraphx_py.cpp
View file @
11c7cee5
...
...
@@ -14,7 +14,7 @@
namespace
py
=
pybind11
;
template
<
class
F
>
struct
skip
_half
struct
throw
_half
{
F
f
;
...
...
@@ -30,10 +30,31 @@ struct skip_half
}
};
template
<
class
F
>
struct
skip_half
{
F
f
;
template
<
class
A
>
void
operator
()(
A
a
)
const
{
f
(
a
);
}
void
operator
()(
migraphx
::
shape
::
as
<
migraphx
::
half
>
)
const
{}
};
template
<
class
F
>
void
visit_type
(
const
migraphx
::
shape
&
s
,
F
f
)
{
s
.
visit_type
(
skip_half
<
F
>
{
f
});
s
.
visit_type
(
throw_half
<
F
>
{
f
});
}
template
<
class
F
>
void
visit_types
(
F
f
)
{
migraphx
::
shape
::
visit_types
(
skip_half
<
F
>
{
f
});
}
template
<
class
T
>
...
...
@@ -52,6 +73,16 @@ py::buffer_info to_buffer_info(T& x)
return
b
;
}
migraphx
::
shape
to_shape
(
const
py
::
buffer_info
&
info
)
{
migraphx
::
shape
::
type_t
t
;
visit_types
([
&
](
auto
as
)
{
if
(
info
.
format
==
py
::
format_descriptor
<
decltype
(
as
())
>::
format
())
t
=
as
.
type_enum
();
});
return
migraphx
::
shape
{
t
,
info
.
shape
,
info
.
strides
};
}
PYBIND11_MODULE
(
migraphx
,
m
)
{
py
::
class_
<
migraphx
::
shape
>
(
m
,
"shape"
)
...
...
@@ -70,7 +101,11 @@ PYBIND11_MODULE(migraphx, m)
.
def
(
"__repr__"
,
[](
const
migraphx
::
shape
&
s
)
{
return
migraphx
::
to_string
(
s
);
});
py
::
class_
<
migraphx
::
argument
>
(
m
,
"argument"
,
py
::
buffer_protocol
())
.
def_buffer
([](
migraphx
::
argument
&
x
)
->
py
::
buffer_info
{
return
to_buffer_info
(
x
);
});
.
def_buffer
([](
migraphx
::
argument
&
x
)
->
py
::
buffer_info
{
return
to_buffer_info
(
x
);
})
.
def
(
"__init__"
,
[](
migraphx
::
argument
&
x
,
py
::
buffer
b
)
{
py
::
buffer_info
info
=
b
.
request
();
new
(
&
x
)
migraphx
::
argument
(
to_shape
(
info
),
info
.
ptr
);
});
py
::
class_
<
migraphx
::
target
>
(
m
,
"target"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment