Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4f07b8f1
Commit
4f07b8f1
authored
Apr 11, 2022
by
Shucai Xiao
Browse files
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into test_branch_for_ort2
parents
af110526
1e0bbd78
Changes
63
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
59 additions
and
16 deletions
+59
-16
tools/api/api.cpp
tools/api/api.cpp
+34
-0
tools/include/schedule_model.hpp
tools/include/schedule_model.hpp
+6
-6
tools/te.py
tools/te.py
+19
-10
No files found.
tools/api/api.cpp
View file @
4f07b8f1
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include <migraphx/ref/target.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <algorithm>
#include <algorithm>
...
@@ -212,6 +213,39 @@ void print_program(const program& p) { std::cout << p << std::endl; }
...
@@ -212,6 +213,39 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void
print_module
(
const
module
&
m
)
{
std
::
cout
<<
m
<<
std
::
endl
;
}
void
print_module
(
const
module
&
m
)
{
std
::
cout
<<
m
<<
std
::
endl
;
}
struct
experimental_custom_op
{
std
::
string
name
;
experimental_custom_op
()
=
default
;
experimental_custom_op
(
std
::
string
pname
)
:
name
(
std
::
move
(
pname
))
{}
};
template
<
class
CustomOp
>
struct
custom_operation
{
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
,
F
)
{
return
pack
();
}
CustomOp
op
;
std
::
string
name
()
const
{
return
op
.
xobject
.
name
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
op
.
compute_shape
(
std
::
move
(
inputs
));
}
argument
compute
(
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPHX_THROW
(
"Not computable"
);
}
};
template
<
class
CustomOp
>
void
register_custom_op
(
const
CustomOp
&
op
)
{
register_op
(
custom_operation
<
CustomOp
>
{
op
});
}
migraphx
::
context
get_context
(
const
program
&
p
)
{
return
p
.
get_context
();
}
migraphx
::
context
get_context
(
const
program
&
p
)
{
return
p
.
get_context
();
}
}
// namespace migraphx
}
// namespace migraphx
...
...
tools/include/schedule_model.hpp
View file @
4f07b8f1
...
@@ -26,11 +26,11 @@ struct schedule_model
...
@@ -26,11 +26,11 @@ struct schedule_model
/// Get the number of concurrent instruction allowed
/// Get the number of concurrent instruction allowed
std
::
size_t
concurrency
()
const
;
std
::
size_t
concurrency
()
const
;
/// Schedule a concurrent instruction
/// Schedule a concurrent instruction
void
sched
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
n
)
const
;
void
sched
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
n
)
const
;
// Insert necessary waits before an instruction
// Insert necessary waits before an instruction
void
wait
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
void
wait
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
// Insert necessary records after an instruction
// Insert necessary records after an instruction
void
record
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
void
record
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
/// Compute weights for an operation
/// Compute weights for an operation
std
::
size_t
weight
(
const
operation
&
op
)
const
;
std
::
size_t
weight
(
const
operation
&
op
)
const
;
};
};
...
@@ -40,9 +40,9 @@ struct schedule_model
...
@@ -40,9 +40,9 @@ struct schedule_model
<%
<%
interface
(
'
schedule_model
'
,
interface
(
'
schedule_model
'
,
virtual
(
'
concurrency
'
,
returns
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
concurrency
'
,
returns
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
sched
'
,
p
=
'
module
&
'
,
ins
=
'
instruction_ref
'
,
n
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
sched
'
,
m
=
'
module
&
'
,
ins
=
'
instruction_ref
'
,
n
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
wait
'
,
p
=
'
module
&
'
,
ins
=
'
instruction_ref
'
,
wait_id
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
wait
'
,
m
=
'
module
&
'
,
ins
=
'
instruction_ref
'
,
wait_id
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
record
'
,
p
=
'
module
&
'
,
ins
=
'
instruction_ref
'
,
wait_id
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
record
'
,
m
=
'
module
&
'
,
ins
=
'
instruction_ref
'
,
wait_id
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
weight
'
,
returns
=
'
std
::
size_t
'
,
op
=
'
const
operation
&
'
,
const
=
True
)
virtual
(
'
weight
'
,
returns
=
'
std
::
size_t
'
,
op
=
'
const
operation
&
'
,
const
=
True
)
)
)
%>
%>
...
...
tools/te.py
View file @
4f07b8f1
...
@@ -12,16 +12,15 @@ headers = '''
...
@@ -12,16 +12,15 @@ headers = '''
'''
'''
form
=
string
.
Template
(
'''
form
=
string
.
Template
(
'''
#ifdef TYPE_ERASED_DECLARATION
/*
// Type-erased interface for:
* Type-erased interface for:
struct ${struct_name}
*
{
* struct ${struct_name}
${decl_members}
* {
};
${comment_members}
* };
#else
*
*/
struct ${struct_name}
struct ${struct_name}
{
{
...
@@ -189,6 +188,7 @@ inline const ValueType & any_cast(const ${struct_name} & x)
...
@@ -189,6 +188,7 @@ inline const ValueType & any_cast(const ${struct_name} & x)
if (y == nullptr) throw std::bad_cast();
if (y == nullptr) throw std::bad_cast();
return *y;
return *y;
}
}
#endif
'''
)
'''
)
nonvirtual_member
=
string
.
Template
(
'''
nonvirtual_member
=
string
.
Template
(
'''
...
@@ -214,6 +214,10 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override
...
@@ -214,6 +214,10 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override
comment_member
=
string
.
Template
(
comment_member
=
string
.
Template
(
'''* ${friend} ${return_type} ${name}(${params}) ${const};'''
)
'''* ${friend} ${return_type} ${name}(${params}) ${const};'''
)
decl_member
=
string
.
Template
(
''' ${comment}
${friend} ${return_type} ${name}(${params}) ${const};
'''
)
default_member
=
string
.
Template
(
'''
default_member
=
string
.
Template
(
'''
template<class T>
template<class T>
static auto private_detail_te_default_${name}(char, T&& private_detail_te_self ${comma} ${member_params})
static auto private_detail_te_default_${name}(char, T&& private_detail_te_self ${comma} ${member_params})
...
@@ -279,7 +283,8 @@ def convert_member(d, struct_name):
...
@@ -279,7 +283,8 @@ def convert_member(d, struct_name):
'this'
:
'(*this)'
,
'this'
:
'(*this)'
,
'using'
:
''
,
'using'
:
''
,
'brief'
:
''
,
'brief'
:
''
,
'return_'
:
''
'return_'
:
''
,
'comment'
:
'// '
}
}
args
=
[]
args
=
[]
params
=
[]
params
=
[]
...
@@ -306,6 +311,7 @@ def convert_member(d, struct_name):
...
@@ -306,6 +311,7 @@ def convert_member(d, struct_name):
member
[
'friend'
]
=
'friend'
member
[
'friend'
]
=
'friend'
elif
x
==
'default'
:
elif
x
==
'default'
:
member
[
'default'
]
=
t
member
[
'default'
]
=
t
member
[
'comment'
]
=
member
[
'comment'
]
+
'(optional)'
elif
x
==
'using'
:
elif
x
==
'using'
:
member
[
'using'
]
=
'using {};'
.
format
(
d
[
name
][
'using'
])
member
[
'using'
]
=
'using {};'
.
format
(
d
[
name
][
'using'
])
elif
x
==
'__brief__'
:
elif
x
==
'__brief__'
:
...
@@ -347,18 +353,21 @@ def generate_form(name, members):
...
@@ -347,18 +353,21 @@ def generate_form(name, members):
virtual_members
=
[]
virtual_members
=
[]
comment_members
=
[]
comment_members
=
[]
default_members
=
[]
default_members
=
[]
decl_members
=
[]
for
member
in
members
:
for
member
in
members
:
m
=
convert_member
(
member
,
name
)
m
=
convert_member
(
member
,
name
)
nonvirtual_members
.
append
(
nonvirtual_member
.
substitute
(
m
))
nonvirtual_members
.
append
(
nonvirtual_member
.
substitute
(
m
))
pure_virtual_members
.
append
(
pure_virtual_member
.
substitute
(
m
))
pure_virtual_members
.
append
(
pure_virtual_member
.
substitute
(
m
))
virtual_members
.
append
(
virtual_member
.
substitute
(
m
))
virtual_members
.
append
(
virtual_member
.
substitute
(
m
))
comment_members
.
append
(
comment_member
.
substitute
(
m
))
comment_members
.
append
(
comment_member
.
substitute
(
m
))
decl_members
.
append
(
decl_member
.
substitute
(
m
))
if
'default'
in
m
:
if
'default'
in
m
:
default_members
.
append
(
default_member
.
substitute
(
m
))
default_members
.
append
(
default_member
.
substitute
(
m
))
return
form
.
substitute
(
nonvirtual_members
=
''
.
join
(
nonvirtual_members
),
return
form
.
substitute
(
nonvirtual_members
=
''
.
join
(
nonvirtual_members
),
pure_virtual_members
=
''
.
join
(
pure_virtual_members
),
pure_virtual_members
=
''
.
join
(
pure_virtual_members
),
virtual_members
=
''
.
join
(
virtual_members
),
virtual_members
=
''
.
join
(
virtual_members
),
default_members
=
''
.
join
(
default_members
),
default_members
=
''
.
join
(
default_members
),
decl_members
=
''
.
join
(
decl_members
),
comment_members
=
'
\n
'
.
join
(
comment_members
),
comment_members
=
'
\n
'
.
join
(
comment_members
),
struct_name
=
name
)
struct_name
=
name
)
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment