Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f320a3da
"driver/conv_driver.cpp" did not exist on "d075adf12642815a0755823b4d268766a6c2346c"
Commit
f320a3da
authored
Jul 18, 2018
by
Paul
Browse files
Auto cast context
parent
29448044
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
15 deletions
+36
-15
src/include/migraph/operation.hpp
src/include/migraph/operation.hpp
+8
-1
src/targets/miopen/lowering.cpp
src/targets/miopen/lowering.cpp
+5
-10
tools/include/operation.hpp
tools/include/operation.hpp
+8
-1
tools/te.py
tools/te.py
+15
-3
No files found.
src/include/migraph/operation.hpp
View file @
f320a3da
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include <migraph/shape.hpp>
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
namespace
migraph
{
namespace
migraph
{
...
@@ -22,6 +23,12 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...
@@ -22,6 +23,12 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
}
// namespace operation_stream
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
input
)
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
/*
/*
* Type-erased interface for:
* Type-erased interface for:
*
*
...
@@ -169,7 +176,7 @@ struct operation
...
@@ -169,7 +176,7 @@ struct operation
argument
compute
(
context
&
ctx
,
shape
output
,
std
::
vector
<
argument
>
input
)
const
override
argument
compute
(
context
&
ctx
,
shape
output
,
std
::
vector
<
argument
>
input
)
const
override
{
{
return
private_detail_te_value
.
compute
(
ctx
,
std
::
move
(
output
),
std
::
move
(
input
));
return
compute_op
(
private_detail_te_value
,
ctx
,
std
::
move
(
output
),
std
::
move
(
input
));
}
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
...
...
src/targets/miopen/lowering.cpp
View file @
f320a3da
...
@@ -25,9 +25,8 @@ struct miopen_convolution
...
@@ -25,9 +25,8 @@ struct miopen_convolution
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
}
argument
compute
(
mi
graph
::
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
mi
open_
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
w_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
w_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
auto
y_desc
=
make_tensor
(
output_shape
);
...
@@ -77,9 +76,8 @@ struct miopen_pooling
...
@@ -77,9 +76,8 @@ struct miopen_pooling
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
compute_shape
({
inputs
.
at
(
1
)});
return
op
.
compute_shape
({
inputs
.
at
(
1
)});
}
}
argument
compute
(
mi
graph
::
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
mi
open_
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
auto
y_desc
=
make_tensor
(
output_shape
);
...
@@ -110,7 +108,7 @@ struct miopen_add
...
@@ -110,7 +108,7 @@ struct miopen_add
return
inputs
.
at
(
0
);
return
inputs
.
at
(
0
);
}
}
argument
compute
(
mi
graph
::
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
mi
open_
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
if
(
args
[
1
].
get_shape
().
broadcasted
())
if
(
args
[
1
].
get_shape
().
broadcasted
())
{
{
...
@@ -127,7 +125,6 @@ struct miopen_add
...
@@ -127,7 +125,6 @@ struct miopen_add
}
}
else
else
{
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1
,
beta
=
0
;
float
alpha
=
1
,
beta
=
0
;
auto
a_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
a_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
b_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
b_desc
=
make_tensor
(
args
[
1
].
get_shape
());
...
@@ -157,9 +154,8 @@ struct miopen_gemm
...
@@ -157,9 +154,8 @@ struct miopen_gemm
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
}
argument
compute
(
mi
graph
::
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
mi
open_
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1.0
f
;
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
float
beta
=
0.0
f
;
rocblas_int
lda
=
args
[
0
].
get_shape
().
lens
()[
1
];
rocblas_int
lda
=
args
[
0
].
get_shape
().
lens
()[
1
];
...
@@ -196,9 +192,8 @@ struct miopen_relu
...
@@ -196,9 +192,8 @@ struct miopen_relu
return
inputs
.
at
(
1
);
return
inputs
.
at
(
1
);
}
}
argument
compute
(
mi
graph
::
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
mi
open_
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1
,
beta
=
0
;
float
alpha
=
1
,
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
auto
y_desc
=
make_tensor
(
output_shape
);
...
...
tools/include/operation.hpp
View file @
f320a3da
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include <migraph/shape.hpp>
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
namespace
migraph
{
namespace
migraph
{
...
@@ -22,11 +23,17 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...
@@ -22,11 +23,17 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
}
// namespace operation_stream
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
input
)
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
<%
<%
interface
(
'
operation
'
,
interface
(
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
shape
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
shape
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
,
default
=
'
compute_op
'
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
migraph
::
operation_stream
::
operator
<<
'
)
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
migraph
::
operation_stream
::
operator
<<
'
)
)
)
%>
%>
...
...
tools/te.py
View file @
f320a3da
...
@@ -213,16 +213,21 @@ def internal_name(name):
...
@@ -213,16 +213,21 @@ def internal_name(name):
else
:
else
:
return
name
return
name
def
generate_call
(
m
,
friend
):
def
generate_call
(
m
,
friend
,
indirect
):
if
m
[
'name'
].
startswith
(
'operator'
):
if
m
[
'name'
].
startswith
(
'operator'
):
op
=
m
[
'name'
][
8
:]
op
=
m
[
'name'
][
8
:]
args
=
m
[
'args'
]
args
=
m
[
'args'
]
if
','
in
args
:
if
','
in
args
:
return
args
.
replace
(
','
,
op
)
return
args
.
replace
(
','
,
op
)
else
:
else
:
return
string
.
Template
(
'${op}${arg
a
}'
).
substitute
(
op
=
op
,
args
=
args
)
return
string
.
Template
(
'${op}${arg
s
}'
).
substitute
(
op
=
op
,
args
=
args
)
if
friend
:
if
friend
:
return
string
.
Template
(
'${name}(${args})'
).
substitute
(
m
)
return
string
.
Template
(
'${name}(${args})'
).
substitute
(
m
)
if
indirect
:
if
m
[
'args'
]:
return
string
.
Template
(
'${default}(private_detail_te_value, ${args})'
).
substitute
(
m
)
else
:
return
string
.
Template
(
'${default}(private_detail_te_value)'
).
substitute
(
m
)
return
string
.
Template
(
'private_detail_te_value.${name}(${args})'
).
substitute
(
m
)
return
string
.
Template
(
'private_detail_te_value.${name}(${args})'
).
substitute
(
m
)
def
convert_member
(
d
,
struct_name
):
def
convert_member
(
d
,
struct_name
):
...
@@ -242,9 +247,12 @@ def convert_member(d, struct_name):
...
@@ -242,9 +247,12 @@ def convert_member(d, struct_name):
member_params
=
[]
member_params
=
[]
skip
=
False
skip
=
False
friend
=
False
friend
=
False
indirect
=
False
if
'friend'
in
d
[
name
]:
if
'friend'
in
d
[
name
]:
friend
=
True
friend
=
True
skip
=
True
skip
=
True
if
'default'
in
d
[
name
]:
indirect
=
True
for
x
in
d
[
name
]:
for
x
in
d
[
name
]:
t
=
d
[
name
][
x
]
t
=
d
[
name
][
x
]
if
x
==
'return'
:
if
x
==
'return'
:
...
@@ -254,8 +262,12 @@ def convert_member(d, struct_name):
...
@@ -254,8 +262,12 @@ def convert_member(d, struct_name):
member
[
'member_const'
]
=
'const'
member
[
'member_const'
]
=
'const'
elif
x
==
'friend'
:
elif
x
==
'friend'
:
member
[
'friend'
]
=
'friend'
member
[
'friend'
]
=
'friend'
elif
x
==
'default'
:
member
[
'default'
]
=
t
elif
x
==
'using'
:
elif
x
==
'using'
:
member
[
'using'
]
=
'using {};'
.
format
(
d
[
name
][
'using'
])
member
[
'using'
]
=
'using {};'
.
format
(
d
[
name
][
'using'
])
elif
x
.
startswith
(
'__'
)
and
x
.
endswith
(
'__'
):
continue
else
:
else
:
use_member
=
not
(
skip
and
struct_name
==
trim_type_name
(
t
))
use_member
=
not
(
skip
and
struct_name
==
trim_type_name
(
t
))
arg_name
=
x
arg_name
=
x
...
@@ -278,7 +290,7 @@ def convert_member(d, struct_name):
...
@@ -278,7 +290,7 @@ def convert_member(d, struct_name):
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'member_params'
]
=
','
.
join
(
member_params
)
member
[
'member_params'
]
=
','
.
join
(
member_params
)
member
[
'call'
]
=
generate_call
(
member
,
friend
)
member
[
'call'
]
=
generate_call
(
member
,
friend
,
indirect
)
return
member
return
member
return
None
return
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment