Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f320a3da
Commit
f320a3da
authored
Jul 18, 2018
by
Paul
Browse files
Auto cast context
parent
29448044
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
15 deletions
+36
-15
src/include/migraph/operation.hpp
src/include/migraph/operation.hpp
+8
-1
src/targets/miopen/lowering.cpp
src/targets/miopen/lowering.cpp
+5
-10
tools/include/operation.hpp
tools/include/operation.hpp
+8
-1
tools/te.py
tools/te.py
+15
-3
No files found.
src/include/migraph/operation.hpp
View file @
f320a3da
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include <migraph/shape.hpp>
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
namespace
migraph
{
namespace
migraph
{
...
@@ -22,6 +23,12 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...
@@ -22,6 +23,12 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
}
// namespace operation_stream
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
input
)
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
/*
/*
* Type-erased interface for:
* Type-erased interface for:
*
*
...
@@ -169,7 +176,7 @@ struct operation
...
@@ -169,7 +176,7 @@ struct operation
argument
compute
(
context
&
ctx
,
shape
output
,
std
::
vector
<
argument
>
input
)
const
override
argument
compute
(
context
&
ctx
,
shape
output
,
std
::
vector
<
argument
>
input
)
const
override
{
{
return
private_detail_te_value
.
compute
(
ctx
,
std
::
move
(
output
),
std
::
move
(
input
));
return
compute_op
(
private_detail_te_value
,
ctx
,
std
::
move
(
output
),
std
::
move
(
input
));
}
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
...
...
src/targets/miopen/lowering.cpp
View file @
f320a3da
...
@@ -25,9 +25,8 @@ struct miopen_convolution
...
@@ -25,9 +25,8 @@ struct miopen_convolution
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
}
argument
compute
(
mi
graph
::
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
mi
open_
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
w_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
w_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
auto
y_desc
=
make_tensor
(
output_shape
);
...
@@ -77,9 +76,8 @@ struct miopen_pooling
...
@@ -77,9 +76,8 @@ struct miopen_pooling
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
compute_shape
({
inputs
.
at
(
1
)});
return
op
.
compute_shape
({
inputs
.
at
(
1
)});
}
}
argument
compute
(
mi
graph
::
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
mi
open_
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
auto
y_desc
=
make_tensor
(
output_shape
);
...
@@ -110,7 +108,7 @@ struct miopen_add
...
@@ -110,7 +108,7 @@ struct miopen_add
return
inputs
.
at
(
0
);
return
inputs
.
at
(
0
);
}
}
argument
compute
(
mi
graph
::
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
mi
open_
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
if
(
args
[
1
].
get_shape
().
broadcasted
())
if
(
args
[
1
].
get_shape
().
broadcasted
())
{
{
...
@@ -127,7 +125,6 @@ struct miopen_add
...
@@ -127,7 +125,6 @@ struct miopen_add
}
}
else
else
{
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1
,
beta
=
0
;
float
alpha
=
1
,
beta
=
0
;
auto
a_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
a_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
b_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
b_desc
=
make_tensor
(
args
[
1
].
get_shape
());
...
@@ -157,9 +154,8 @@ struct miopen_gemm
...
@@ -157,9 +154,8 @@ struct miopen_gemm
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
}
argument
compute
(
mi
graph
::
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
mi
open_
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1.0
f
;
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
float
beta
=
0.0
f
;
rocblas_int
lda
=
args
[
0
].
get_shape
().
lens
()[
1
];
rocblas_int
lda
=
args
[
0
].
get_shape
().
lens
()[
1
];
...
@@ -196,9 +192,8 @@ struct miopen_relu
...
@@ -196,9 +192,8 @@ struct miopen_relu
return
inputs
.
at
(
1
);
return
inputs
.
at
(
1
);
}
}
argument
compute
(
mi
graph
::
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
mi
open_
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1
,
beta
=
0
;
float
alpha
=
1
,
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
auto
y_desc
=
make_tensor
(
output_shape
);
...
...
tools/include/operation.hpp
View file @
f320a3da
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include <migraph/shape.hpp>
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
namespace
migraph
{
namespace
migraph
{
...
@@ -22,11 +23,17 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...
@@ -22,11 +23,17 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
}
// namespace operation_stream
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
input
)
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
<%
<%
interface
(
'
operation
'
,
interface
(
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
shape
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
shape
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
,
default
=
'
compute_op
'
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
migraph
::
operation_stream
::
operator
<<
'
)
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
migraph
::
operation_stream
::
operator
<<
'
)
)
)
%>
%>
...
...
tools/te.py
View file @
f320a3da
...
@@ -213,16 +213,21 @@ def internal_name(name):
...
@@ -213,16 +213,21 @@ def internal_name(name):
else
:
else
:
return
name
return
name
def
generate_call
(
m
,
friend
):
def
generate_call
(
m
,
friend
,
indirect
):
if
m
[
'name'
].
startswith
(
'operator'
):
if
m
[
'name'
].
startswith
(
'operator'
):
op
=
m
[
'name'
][
8
:]
op
=
m
[
'name'
][
8
:]
args
=
m
[
'args'
]
args
=
m
[
'args'
]
if
','
in
args
:
if
','
in
args
:
return
args
.
replace
(
','
,
op
)
return
args
.
replace
(
','
,
op
)
else
:
else
:
return
string
.
Template
(
'${op}${arg
a
}'
).
substitute
(
op
=
op
,
args
=
args
)
return
string
.
Template
(
'${op}${arg
s
}'
).
substitute
(
op
=
op
,
args
=
args
)
if
friend
:
if
friend
:
return
string
.
Template
(
'${name}(${args})'
).
substitute
(
m
)
return
string
.
Template
(
'${name}(${args})'
).
substitute
(
m
)
if
indirect
:
if
m
[
'args'
]:
return
string
.
Template
(
'${default}(private_detail_te_value, ${args})'
).
substitute
(
m
)
else
:
return
string
.
Template
(
'${default}(private_detail_te_value)'
).
substitute
(
m
)
return
string
.
Template
(
'private_detail_te_value.${name}(${args})'
).
substitute
(
m
)
return
string
.
Template
(
'private_detail_te_value.${name}(${args})'
).
substitute
(
m
)
def
convert_member
(
d
,
struct_name
):
def
convert_member
(
d
,
struct_name
):
...
@@ -242,9 +247,12 @@ def convert_member(d, struct_name):
...
@@ -242,9 +247,12 @@ def convert_member(d, struct_name):
member_params
=
[]
member_params
=
[]
skip
=
False
skip
=
False
friend
=
False
friend
=
False
indirect
=
False
if
'friend'
in
d
[
name
]:
if
'friend'
in
d
[
name
]:
friend
=
True
friend
=
True
skip
=
True
skip
=
True
if
'default'
in
d
[
name
]:
indirect
=
True
for
x
in
d
[
name
]:
for
x
in
d
[
name
]:
t
=
d
[
name
][
x
]
t
=
d
[
name
][
x
]
if
x
==
'return'
:
if
x
==
'return'
:
...
@@ -254,8 +262,12 @@ def convert_member(d, struct_name):
...
@@ -254,8 +262,12 @@ def convert_member(d, struct_name):
member
[
'member_const'
]
=
'const'
member
[
'member_const'
]
=
'const'
elif
x
==
'friend'
:
elif
x
==
'friend'
:
member
[
'friend'
]
=
'friend'
member
[
'friend'
]
=
'friend'
elif
x
==
'default'
:
member
[
'default'
]
=
t
elif
x
==
'using'
:
elif
x
==
'using'
:
member
[
'using'
]
=
'using {};'
.
format
(
d
[
name
][
'using'
])
member
[
'using'
]
=
'using {};'
.
format
(
d
[
name
][
'using'
])
elif
x
.
startswith
(
'__'
)
and
x
.
endswith
(
'__'
):
continue
else
:
else
:
use_member
=
not
(
skip
and
struct_name
==
trim_type_name
(
t
))
use_member
=
not
(
skip
and
struct_name
==
trim_type_name
(
t
))
arg_name
=
x
arg_name
=
x
...
@@ -278,7 +290,7 @@ def convert_member(d, struct_name):
...
@@ -278,7 +290,7 @@ def convert_member(d, struct_name):
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'member_params'
]
=
','
.
join
(
member_params
)
member
[
'member_params'
]
=
','
.
join
(
member_params
)
member
[
'call'
]
=
generate_call
(
member
,
friend
)
member
[
'call'
]
=
generate_call
(
member
,
friend
,
indirect
)
return
member
return
member
return
None
return
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment