Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f320a3da
Commit
f320a3da
authored
Jul 18, 2018
by
Paul
Browse files
Auto cast context
parent
29448044
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
15 deletions
+36
-15
src/include/migraph/operation.hpp
src/include/migraph/operation.hpp
+8
-1
src/targets/miopen/lowering.cpp
src/targets/miopen/lowering.cpp
+5
-10
tools/include/operation.hpp
tools/include/operation.hpp
+8
-1
tools/te.py
tools/te.py
+15
-3
No files found.
src/include/migraph/operation.hpp
View file @
f320a3da
...
...
@@ -9,6 +9,7 @@
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
namespace
migraph
{
...
...
@@ -22,6 +23,12 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
input
)
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
/*
* Type-erased interface for:
*
...
...
@@ -169,7 +176,7 @@ struct operation
argument
compute
(
context
&
ctx
,
shape
output
,
std
::
vector
<
argument
>
input
)
const
override
{
return
private_detail_te_value
.
compute
(
ctx
,
std
::
move
(
output
),
std
::
move
(
input
));
return
compute_op
(
private_detail_te_value
,
ctx
,
std
::
move
(
output
),
std
::
move
(
input
));
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
...
...
src/targets/miopen/lowering.cpp
View file @
f320a3da
...
...
@@ -25,9 +25,8 @@ struct miopen_convolution
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
argument
compute
(
mi
graph
::
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
mi
open_
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
w_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
...
...
@@ -77,9 +76,8 @@ struct miopen_pooling
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
compute_shape
({
inputs
.
at
(
1
)});
}
argument
compute
(
mi
graph
::
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
mi
open_
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
...
...
@@ -110,7 +108,7 @@ struct miopen_add
return
inputs
.
at
(
0
);
}
argument
compute
(
mi
graph
::
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
mi
open_
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
if
(
args
[
1
].
get_shape
().
broadcasted
())
{
...
...
@@ -127,7 +125,6 @@ struct miopen_add
}
else
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1
,
beta
=
0
;
auto
a_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
b_desc
=
make_tensor
(
args
[
1
].
get_shape
());
...
...
@@ -157,9 +154,8 @@ struct miopen_gemm
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
argument
compute
(
mi
graph
::
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
mi
open_
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
rocblas_int
lda
=
args
[
0
].
get_shape
().
lens
()[
1
];
...
...
@@ -196,9 +192,8 @@ struct miopen_relu
return
inputs
.
at
(
1
);
}
argument
compute
(
mi
graph
::
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
mi
open_
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1
,
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
...
...
tools/include/operation.hpp
View file @
f320a3da
...
...
@@ -9,6 +9,7 @@
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
namespace
migraph
{
...
...
@@ -22,11 +23,17 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
input
)
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
<%
interface
(
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
shape
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
shape
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
,
default
=
'
compute_op
'
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
migraph
::
operation_stream
::
operator
<<
'
)
)
%>
...
...
tools/te.py
View file @
f320a3da
...
...
@@ -213,16 +213,21 @@ def internal_name(name):
else
:
return
name
def
generate_call
(
m
,
friend
):
def
generate_call
(
m
,
friend
,
indirect
):
if
m
[
'name'
].
startswith
(
'operator'
):
op
=
m
[
'name'
][
8
:]
args
=
m
[
'args'
]
if
','
in
args
:
return
args
.
replace
(
','
,
op
)
else
:
return
string
.
Template
(
'${op}${arg
a
}'
).
substitute
(
op
=
op
,
args
=
args
)
return
string
.
Template
(
'${op}${arg
s
}'
).
substitute
(
op
=
op
,
args
=
args
)
if
friend
:
return
string
.
Template
(
'${name}(${args})'
).
substitute
(
m
)
if
indirect
:
if
m
[
'args'
]:
return
string
.
Template
(
'${default}(private_detail_te_value, ${args})'
).
substitute
(
m
)
else
:
return
string
.
Template
(
'${default}(private_detail_te_value)'
).
substitute
(
m
)
return
string
.
Template
(
'private_detail_te_value.${name}(${args})'
).
substitute
(
m
)
def
convert_member
(
d
,
struct_name
):
...
...
@@ -242,9 +247,12 @@ def convert_member(d, struct_name):
member_params
=
[]
skip
=
False
friend
=
False
indirect
=
False
if
'friend'
in
d
[
name
]:
friend
=
True
skip
=
True
if
'default'
in
d
[
name
]:
indirect
=
True
for
x
in
d
[
name
]:
t
=
d
[
name
][
x
]
if
x
==
'return'
:
...
...
@@ -254,8 +262,12 @@ def convert_member(d, struct_name):
member
[
'member_const'
]
=
'const'
elif
x
==
'friend'
:
member
[
'friend'
]
=
'friend'
elif
x
==
'default'
:
member
[
'default'
]
=
t
elif
x
==
'using'
:
member
[
'using'
]
=
'using {};'
.
format
(
d
[
name
][
'using'
])
elif
x
.
startswith
(
'__'
)
and
x
.
endswith
(
'__'
):
continue
else
:
use_member
=
not
(
skip
and
struct_name
==
trim_type_name
(
t
))
arg_name
=
x
...
...
@@ -278,7 +290,7 @@ def convert_member(d, struct_name):
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'member_params'
]
=
','
.
join
(
member_params
)
member
[
'call'
]
=
generate_call
(
member
,
friend
)
member
[
'call'
]
=
generate_call
(
member
,
friend
,
indirect
)
return
member
return
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment