Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c9cd5e1f
Commit
c9cd5e1f
authored
Feb 09, 2019
by
Paul
Browse files
Formatting
parent
11c7cee5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
18 deletions
+16
-18
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+13
-14
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+3
-4
No files found.
src/include/migraphx/shape.hpp
View file @
c9cd5e1f
...
@@ -61,16 +61,19 @@ struct shape
...
@@ -61,16 +61,19 @@ struct shape
shape
(
type_t
t
);
shape
(
type_t
t
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
template
<
class
Range
>
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
{}
template
<
class
Range1
,
class
Range2
>
template
<
class
Range
>
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
{
}
template
<
class
Range1
,
class
Range2
>
shape
(
type_t
t
,
const
Range1
&
l
,
const
Range2
&
s
)
shape
(
type_t
t
,
const
Range1
&
l
,
const
Range2
&
s
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()),
std
::
vector
<
std
::
size_t
>
(
s
.
begin
(),
s
.
end
()))
:
shape
(
t
,
{}
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()),
std
::
vector
<
std
::
size_t
>
(
s
.
begin
(),
s
.
end
()))
{
}
type_t
type
()
const
;
type_t
type
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
...
@@ -152,10 +155,7 @@ struct shape
...
@@ -152,10 +155,7 @@ struct shape
return
reinterpret_cast
<
const
T
*>
(
buffer
)
+
n
;
return
reinterpret_cast
<
const
T
*>
(
buffer
)
+
n
;
}
}
type_t
type_enum
()
const
type_t
type_enum
()
const
{
return
get_type
<
T
>
{};
}
{
return
get_type
<
T
>
{};
}
};
};
template
<
class
Visitor
>
template
<
class
Visitor
>
...
@@ -174,8 +174,7 @@ struct shape
...
@@ -174,8 +174,7 @@ struct shape
template
<
class
Visitor
>
template
<
class
Visitor
>
static
void
visit_types
(
Visitor
v
)
static
void
visit_types
(
Visitor
v
)
{
{
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL(x, t) \
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL(x, t) v(as<t>());
v(as<t>());
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
)
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
)
#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
}
}
...
...
src/py/migraphx_py.cpp
View file @
c9cd5e1f
...
@@ -41,8 +41,7 @@ struct skip_half
...
@@ -41,8 +41,7 @@ struct skip_half
f
(
a
);
f
(
a
);
}
}
void
operator
()(
migraphx
::
shape
::
as
<
migraphx
::
half
>
)
const
void
operator
()(
migraphx
::
shape
::
as
<
migraphx
::
half
>
)
const
{}
{}
};
};
template
<
class
F
>
template
<
class
F
>
...
@@ -77,7 +76,7 @@ migraphx::shape to_shape(const py::buffer_info& info)
...
@@ -77,7 +76,7 @@ migraphx::shape to_shape(const py::buffer_info& info)
{
{
migraphx
::
shape
::
type_t
t
;
migraphx
::
shape
::
type_t
t
;
visit_types
([
&
](
auto
as
)
{
visit_types
([
&
](
auto
as
)
{
if
(
info
.
format
==
py
::
format_descriptor
<
decltype
(
as
())
>::
format
())
if
(
info
.
format
==
py
::
format_descriptor
<
decltype
(
as
())
>::
format
())
t
=
as
.
type_enum
();
t
=
as
.
type_enum
();
});
});
return
migraphx
::
shape
{
t
,
info
.
shape
,
info
.
strides
};
return
migraphx
::
shape
{
t
,
info
.
shape
,
info
.
strides
};
...
@@ -104,7 +103,7 @@ PYBIND11_MODULE(migraphx, m)
...
@@ -104,7 +103,7 @@ PYBIND11_MODULE(migraphx, m)
.
def_buffer
([](
migraphx
::
argument
&
x
)
->
py
::
buffer_info
{
return
to_buffer_info
(
x
);
})
.
def_buffer
([](
migraphx
::
argument
&
x
)
->
py
::
buffer_info
{
return
to_buffer_info
(
x
);
})
.
def
(
"__init__"
,
[](
migraphx
::
argument
&
x
,
py
::
buffer
b
)
{
.
def
(
"__init__"
,
[](
migraphx
::
argument
&
x
,
py
::
buffer
b
)
{
py
::
buffer_info
info
=
b
.
request
();
py
::
buffer_info
info
=
b
.
request
();
new
(
&
x
)
migraphx
::
argument
(
to_shape
(
info
),
info
.
ptr
);
new
(
&
x
)
migraphx
::
argument
(
to_shape
(
info
),
info
.
ptr
);
});
});
py
::
class_
<
migraphx
::
target
>
(
m
,
"target"
);
py
::
class_
<
migraphx
::
target
>
(
m
,
"target"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment