Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
faefeef9
Unverified
Commit
faefeef9
authored
May 25, 2022
by
Charlie Lin
Committed by
GitHub
May 25, 2022
Browse files
Merge branch 'develop' into dyn_shape_update
parents
97a40ac3
bf0a4713
Changes
94
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
107 additions
and
83 deletions
+107
-83
doc/src/reference/py.rst
doc/src/reference/py.rst
+7
-0
examples/nlp/python_bert_squad/BERT-Squad.ipynb
examples/nlp/python_bert_squad/BERT-Squad.ipynb
+1
-1
examples/nlp/python_bert_squad/README.md
examples/nlp/python_bert_squad/README.md
+1
-1
src/adjust_allocation.cpp
src/adjust_allocation.cpp
+6
-6
src/analyze_streams.cpp
src/analyze_streams.cpp
+14
-14
src/api/api.cpp
src/api/api.cpp
+16
-0
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+5
-0
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+9
-0
src/api/migraphx.py
src/api/migraphx.py
+3
-0
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+8
-8
src/dead_code_elimination.cpp
src/dead_code_elimination.cpp
+6
-22
src/eliminate_allocation.cpp
src/eliminate_allocation.cpp
+4
-4
src/eliminate_common_subexpression.cpp
src/eliminate_common_subexpression.cpp
+5
-5
src/eliminate_concat.cpp
src/eliminate_concat.cpp
+6
-6
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+4
-4
src/eliminate_identity.cpp
src/eliminate_identity.cpp
+8
-8
src/include/migraphx/adjust_allocation.hpp
src/include/migraphx/adjust_allocation.hpp
+1
-1
src/include/migraphx/analyze_streams.hpp
src/include/migraphx/analyze_streams.hpp
+1
-1
src/include/migraphx/auto_contiguous.hpp
src/include/migraphx/auto_contiguous.hpp
+1
-1
src/include/migraphx/check_context.hpp
src/include/migraphx/check_context.hpp
+1
-1
No files found.
doc/src/reference/py.rst
View file @
faefeef9
...
@@ -146,6 +146,13 @@ module
...
@@ -146,6 +146,13 @@ module
:param list[module] mod_args: optional list of module arguments to the operator.
:param list[module] mod_args: optional list of module arguments to the operator.
:rtype instruction
:rtype instruction
.. py:method:: add_literal(data)
Adds constant or literal data of provided shape into the module from python buffer which includes numpy array.
:param py::buffer data: Python buffer or numpy array
:rtype instruction
.. py:method:: add_parameter(name, shape)
.. py:method:: add_parameter(name, shape)
Adds a parameter to the module with provided name and shape.
Adds a parameter to the module with provided name and shape.
...
...
examples/nlp/python_bert_squad/BERT-Squad.ipynb
View file @
faefeef9
...
@@ -62,7 +62,7 @@
...
@@ -62,7 +62,7 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"!wget -nc https://github.com/onnx/models/
raw/master
/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx"
"!wget -nc https://github.com/onnx/models/
blob/main
/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx"
]
]
},
},
{
{
...
...
examples/nlp/python_bert_squad/README.md
View file @
faefeef9
...
@@ -23,7 +23,7 @@ unzip uncased_L-12_H-768_A-12.zip
...
@@ -23,7 +23,7 @@ unzip uncased_L-12_H-768_A-12.zip
```
```
5) Get BERT ONNX model (bertsquad-10.onnx):
5) Get BERT ONNX model (bertsquad-10.onnx):
```
```
wget https://github.com/onnx/models/
raw/master
/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx
wget https://github.com/onnx/models/
blob/main
/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx
```
```
6) Run the inference, it will compile and run the model on three questions and small data provided in
`inputs.json`
:
6) Run the inference, it will compile and run the model on three questions and small data provided in
`inputs.json`
:
```
```
...
...
src/adjust_allocation.cpp
View file @
faefeef9
...
@@ -8,9 +8,9 @@
...
@@ -8,9 +8,9 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
adjust_allocation
::
apply
(
module
&
p
)
const
void
adjust_allocation
::
apply
(
module
&
m
)
const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// skip instruction with no input
// skip instruction with no input
if
(
ins
->
inputs
().
empty
())
if
(
ins
->
inputs
().
empty
())
...
@@ -27,13 +27,13 @@ void adjust_allocation::apply(module& p) const
...
@@ -27,13 +27,13 @@ void adjust_allocation::apply(module& p) const
// of the instruction, reallocate and replace the previous one
// of the instruction, reallocate and replace the previous one
if
(
alias_ins
->
get_shape
()
==
ins
->
get_shape
())
if
(
alias_ins
->
get_shape
()
==
ins
->
get_shape
())
continue
;
continue
;
auto
alloc_ins
=
p
.
insert_instruction
(
ins
,
model
.
allocate
(
ins
->
get_shape
()));
auto
alloc_ins
=
m
.
insert_instruction
(
ins
,
model
.
allocate
(
ins
->
get_shape
()));
p
.
replace_instruction
(
alias_ins
,
alloc_ins
);
m
.
replace_instruction
(
alias_ins
,
alloc_ins
);
// If the memory is an output parameter then copy the memory to the parameter
// If the memory is an output parameter then copy the memory to the parameter
if
(
alias_ins
->
name
()
==
"@param"
)
if
(
alias_ins
->
name
()
==
"@param"
)
{
{
auto
copy
=
p
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
model
.
copy
()),
ins
,
alias_ins
);
auto
copy
=
m
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
model
.
copy
()),
ins
,
alias_ins
);
auto
tail
=
range
(
std
::
next
(
copy
),
p
.
end
());
auto
tail
=
range
(
std
::
next
(
copy
),
m
.
end
());
for
(
auto
i
:
iterator_for
(
tail
))
for
(
auto
i
:
iterator_for
(
tail
))
{
{
if
(
contains
(
i
->
inputs
(),
ins
))
if
(
contains
(
i
->
inputs
(),
ins
))
...
...
src/analyze_streams.cpp
View file @
faefeef9
...
@@ -14,31 +14,31 @@ bool happens_before(const std::vector<std::size_t>& e1, const std::vector<std::s
...
@@ -14,31 +14,31 @@ bool happens_before(const std::vector<std::size_t>& e1, const std::vector<std::s
not
std
::
equal
(
e1
.
begin
(),
e1
.
end
(),
e2
.
begin
(),
e2
.
end
(),
std
::
greater_equal
<>
{});
not
std
::
equal
(
e1
.
begin
(),
e1
.
end
(),
e2
.
begin
(),
e2
.
end
(),
std
::
greater_equal
<>
{});
}
}
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
p
,
const
stream_model
&
m
)
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
m
,
const
stream_model
&
strm
m
)
{
{
using
vector_clock
=
std
::
vector
<
std
::
size_t
>
;
using
vector_clock
=
std
::
vector
<
std
::
size_t
>
;
std
::
vector
<
stream_race
>
races
;
std
::
vector
<
stream_race
>
races
;
auto
nstream
=
m
.
get_nstream
();
auto
nstream
=
strm
m
.
get_nstream
();
std
::
vector
<
vector_clock
>
vclock
(
nstream
,
vector_clock
(
nstream
));
std
::
vector
<
vector_clock
>
vclock
(
nstream
,
vector_clock
(
nstream
));
std
::
unordered_map
<
instruction_ref
,
vector_clock
>
timestamp
;
std
::
unordered_map
<
instruction_ref
,
vector_clock
>
timestamp
;
std
::
unordered_map
<
std
::
size_t
,
vector_clock
>
events
;
std
::
unordered_map
<
std
::
size_t
,
vector_clock
>
events
;
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
not
m
.
has_stream
(
ins
))
if
(
not
strm
m
.
has_stream
(
ins
))
continue
;
continue
;
std
::
size_t
s
=
m
.
get_stream
(
ins
);
std
::
size_t
s
=
strm
m
.
get_stream
(
ins
);
assert
(
s
<
nstream
);
assert
(
s
<
nstream
);
assert
(
vclock
.
size
()
==
nstream
);
assert
(
vclock
.
size
()
==
nstream
);
assert
(
vclock
[
s
].
size
()
==
nstream
);
assert
(
vclock
[
s
].
size
()
==
nstream
);
if
(
m
.
is_record
(
ins
))
if
(
strm
m
.
is_record
(
ins
))
{
{
vclock
[
s
][
s
]
++
;
vclock
[
s
][
s
]
++
;
auto
event
=
m
.
get_event_id
(
ins
);
auto
event
=
strm
m
.
get_event_id
(
ins
);
events
[
event
]
=
vclock
[
s
];
events
[
event
]
=
vclock
[
s
];
}
}
else
if
(
m
.
is_wait
(
ins
))
else
if
(
strm
m
.
is_wait
(
ins
))
{
{
auto
event
=
m
.
get_event_id
(
ins
);
auto
event
=
strm
m
.
get_event_id
(
ins
);
if
(
not
contains
(
events
,
event
))
if
(
not
contains
(
events
,
event
))
MIGRAPHX_THROW
(
"Event is waited on before being recorded: "
+
MIGRAPHX_THROW
(
"Event is waited on before being recorded: "
+
std
::
to_string
(
event
));
std
::
to_string
(
event
));
...
@@ -57,21 +57,21 @@ std::vector<stream_race> analyze_streams(const module& p, const stream_model& m)
...
@@ -57,21 +57,21 @@ std::vector<stream_race> analyze_streams(const module& p, const stream_model& m)
}
}
timestamp
[
ins
]
=
vclock
[
s
];
timestamp
[
ins
]
=
vclock
[
s
];
}
}
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
not
m
.
has_stream
(
ins
))
if
(
not
strm
m
.
has_stream
(
ins
))
continue
;
continue
;
if
(
ins
->
inputs
().
empty
())
if
(
ins
->
inputs
().
empty
())
continue
;
continue
;
std
::
size_t
s
=
m
.
get_stream
(
ins
);
std
::
size_t
s
=
strm
m
.
get_stream
(
ins
);
// Find inputs from different streams
// Find inputs from different streams
std
::
vector
<
instruction_ref
>
inputs
;
std
::
vector
<
instruction_ref
>
inputs
;
fix
([
&
](
auto
self
,
auto
start
)
{
fix
([
&
](
auto
self
,
auto
start
)
{
for
(
auto
input
:
start
->
inputs
())
for
(
auto
input
:
start
->
inputs
())
{
{
if
(
not
m
.
has_stream
(
input
))
if
(
not
strm
m
.
has_stream
(
input
))
self
(
input
);
self
(
input
);
else
if
(
m
.
get_stream
(
input
)
!=
s
)
else
if
(
strm
m
.
get_stream
(
input
)
!=
s
)
inputs
.
push_back
(
input
);
inputs
.
push_back
(
input
);
}
}
})(
ins
);
})(
ins
);
...
...
src/api/api.cpp
View file @
faefeef9
...
@@ -1072,6 +1072,22 @@ migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
...
@@ -1072,6 +1072,22 @@ migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_module_add_literal
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const_migraphx_shape_t
shape
,
const
char
*
buffer
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
module
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter module: Null pointer"
);
if
(
shape
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter shape: Null pointer"
);
*
out
=
allocate
<
migraphx_instruction_t
>
(
(
module
->
object
).
add_literal
((
shape
->
object
),
(
buffer
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
extern
"C"
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_module_t
module
,
const
char
*
name
,
const
char
*
name
,
...
...
src/api/include/migraphx/migraphx.h
View file @
faefeef9
...
@@ -258,6 +258,11 @@ migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instructi
...
@@ -258,6 +258,11 @@ migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instructi
migraphx_instructions_t
args
,
migraphx_instructions_t
args
,
migraphx_modules_t
module_refs
);
migraphx_modules_t
module_refs
);
migraphx_status
migraphx_module_add_literal
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const_migraphx_shape_t
shape
,
const
char
*
buffer
);
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_module_t
module
,
const
char
*
name
,
const
char
*
name
,
...
...
src/api/include/migraphx/migraphx.hpp
View file @
faefeef9
...
@@ -762,6 +762,15 @@ struct module
...
@@ -762,6 +762,15 @@ struct module
return
instruction
(
op_ins
,
own
{});
return
instruction
(
op_ins
,
own
{});
}
}
template
<
typename
T
>
instruction
add_literal
(
const
migraphx
::
shape
&
s
,
T
*
buffer
)
{
migraphx_instruction_t
literal_ins
;
const
auto
*
buffer_ptr
=
reinterpret_cast
<
const
char
*>
(
buffer
);
call
(
&
migraphx_module_add_literal
,
&
literal_ins
,
mm
.
get
(),
s
.
get_handle_ptr
(),
buffer_ptr
);
return
instruction
(
literal_ins
,
own
{});
}
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
{
{
migraphx_instruction_t
param_ins
;
migraphx_instruction_t
param_ins
;
...
...
src/api/migraphx.py
View file @
faefeef9
...
@@ -212,6 +212,9 @@ def module(h):
...
@@ -212,6 +212,9 @@ def module(h):
module_refs
=
'std::vector<migraphx::module*>'
),
module_refs
=
'std::vector<migraphx::module*>'
),
fname
=
'add_instruction'
,
fname
=
'add_instruction'
,
returns
=
'migraphx::instruction_ref'
)
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_literal'
,
api
.
params
(
shape
=
'const migraphx::shape&'
,
buffer
=
'const char*'
),
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_parameter'
,
h
.
method
(
'add_parameter'
,
api
.
params
(
name
=
'const char*'
,
shape
=
'const migraphx::shape&'
),
api
.
params
(
name
=
'const char*'
,
shape
=
'const migraphx::shape&'
),
returns
=
'migraphx::instruction_ref'
)
returns
=
'migraphx::instruction_ref'
)
...
...
src/auto_contiguous.cpp
View file @
faefeef9
...
@@ -8,10 +8,10 @@
...
@@ -8,10 +8,10 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
auto_contiguous
::
apply
(
module
&
p
)
const
void
auto_contiguous
::
apply
(
module
&
m
)
const
{
{
std
::
string
key
=
"require_std_shape"
;
std
::
string
key
=
"require_std_shape"
;
for
(
auto
ins
:
reverse_iterator_for
(
p
))
for
(
auto
ins
:
reverse_iterator_for
(
m
))
{
{
auto
&&
attr
=
ins
->
get_operator
().
attributes
();
auto
&&
attr
=
ins
->
get_operator
().
attributes
();
if
((
attr
.
get
(
key
,
false
)))
if
((
attr
.
get
(
key
,
false
)))
...
@@ -23,18 +23,18 @@ void auto_contiguous::apply(module& p) const
...
@@ -23,18 +23,18 @@ void auto_contiguous::apply(module& p) const
{
{
return
in
;
return
in
;
}
}
return
p
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
in
);
return
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
in
);
});
});
if
(
new_args
!=
args
)
if
(
new_args
!=
args
)
{
{
p
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
new_args
);
m
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
new_args
);
}
}
}
}
}
}
auto
last
=
std
::
prev
(
p
.
end
());
auto
last
=
std
::
prev
(
m
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// for last instruction that is NOT a return
// for last instruction that is NOT a return
if
(
ins
->
outputs
().
empty
()
and
ins
!=
last
)
if
(
ins
->
outputs
().
empty
()
and
ins
!=
last
)
...
@@ -42,8 +42,8 @@ void auto_contiguous::apply(module& p) const
...
@@ -42,8 +42,8 @@ void auto_contiguous::apply(module& p) const
shape
s
=
ins
->
get_shape
();
shape
s
=
ins
->
get_shape
();
if
(
not
s
.
standard
()
and
s
.
elements
()
!=
0
)
if
(
not
s
.
standard
()
and
s
.
elements
()
!=
0
)
{
{
auto
c
=
p
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"contiguous"
),
ins
);
auto
c
=
m
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"contiguous"
),
ins
);
p
.
replace_instruction
(
ins
,
c
);
m
.
replace_instruction
(
ins
,
c
);
}
}
}
}
}
}
...
...
src/dead_code_elimination.cpp
View file @
faefeef9
...
@@ -9,26 +9,6 @@
...
@@ -9,26 +9,6 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Range
,
class
Iterator
>
std
::
ptrdiff_t
bidistance
(
const
Range
&
r
,
Iterator
start
,
Iterator
last
)
{
auto
start_forward
=
start
;
auto
start_backwards
=
start
;
std
::
size_t
n
=
0
;
while
(
start_forward
!=
last
and
start_backwards
!=
last
)
{
n
++
;
if
(
start_forward
!=
r
.
end
())
start_forward
++
;
if
(
start_backwards
!=
r
.
begin
())
start_backwards
--
;
}
if
(
start_forward
==
last
)
return
n
;
else
return
-
n
;
}
void
dead_code_elimination
::
apply
(
program
&
p
)
const
{
p
.
remove_unused_modules
();
}
void
dead_code_elimination
::
apply
(
program
&
p
)
const
{
p
.
remove_unused_modules
();
}
void
dead_code_elimination
::
apply
(
module
&
m
)
const
void
dead_code_elimination
::
apply
(
module
&
m
)
const
...
@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const
...
@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const
if
(
i
->
get_shape
().
elements
()
==
0
and
i
->
name
().
front
()
!=
'@'
and
if
(
i
->
get_shape
().
elements
()
==
0
and
i
->
name
().
front
()
!=
'@'
and
i
->
name
()
!=
"undefined"
and
i
->
name
()
!=
"identity"
)
i
->
name
()
!=
"undefined"
and
i
->
name
()
!=
"identity"
)
continue
;
continue
;
assert
(
bidistance
(
m
,
i
,
last
)
>
0
);
assert
(
std
::
distance
(
m
.
begin
(),
i
)
<=
std
::
distance
(
m
.
begin
(),
last
));
std
::
unordered_set
<
instruction_ref
>
visited
;
fix
([
&
](
auto
self
,
auto
leaf
)
{
fix
([
&
](
auto
self
,
auto
leaf
)
{
if
(
not
m
.
has_instruction
(
leaf
))
if
(
not
m
.
has_instruction
(
leaf
))
return
;
return
;
if
(
leaf
->
outputs
().
empty
())
if
(
leaf
->
outputs
().
empty
())
{
{
// Dont visit inputs twice
if
(
not
visited
.
insert
(
leaf
).
second
)
return
;
std
::
unordered_set
<
instruction_ref
>
args
(
leaf
->
inputs
().
begin
(),
std
::
unordered_set
<
instruction_ref
>
args
(
leaf
->
inputs
().
begin
(),
leaf
->
inputs
().
end
());
leaf
->
inputs
().
end
());
leaf
->
clear_arguments
();
leaf
->
clear_arguments
();
assert
(
bi
distance
(
m
,
last
,
leaf
)
<
0
);
assert
(
std
::
distance
(
m
.
begin
(),
leaf
)
<
std
::
distance
(
m
.
begin
(),
last
)
);
assert
(
leaf
!=
ins
);
assert
(
leaf
!=
ins
);
if
(
leaf
->
name
()
!=
"@param"
)
if
(
leaf
->
name
()
!=
"@param"
)
m
.
move_instruction
(
leaf
,
m
.
end
());
m
.
move_instruction
(
leaf
,
m
.
end
());
...
...
src/eliminate_allocation.cpp
View file @
faefeef9
...
@@ -13,13 +13,13 @@
...
@@ -13,13 +13,13 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_allocation
::
apply
(
module
&
p
)
const
void
eliminate_allocation
::
apply
(
module
&
m
)
const
{
{
assert
(
alignment
>
0
);
assert
(
alignment
>
0
);
std
::
size_t
n
=
0
;
std
::
size_t
n
=
0
;
std
::
vector
<
std
::
pair
<
instruction_ref
,
std
::
size_t
>>
allocs
;
std
::
vector
<
std
::
pair
<
instruction_ref
,
std
::
size_t
>>
allocs
;
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()
!=
allocation_op
)
if
(
ins
->
name
()
!=
allocation_op
)
continue
;
continue
;
...
@@ -30,13 +30,13 @@ void eliminate_allocation::apply(module& p) const
...
@@ -30,13 +30,13 @@ void eliminate_allocation::apply(module& p) const
}
}
if
(
n
>
0
)
if
(
n
>
0
)
{
{
auto
mem
=
p
.
add_parameter
(
"memory"
,
shape
{
shape
::
int8_type
,
{
n
}});
auto
mem
=
m
.
add_parameter
(
"memory"
,
shape
{
shape
::
int8_type
,
{
n
}});
for
(
auto
&&
pp
:
allocs
)
for
(
auto
&&
pp
:
allocs
)
{
{
auto
ins
=
pp
.
first
;
auto
ins
=
pp
.
first
;
auto
s
=
ins
->
get_shape
();
auto
s
=
ins
->
get_shape
();
auto
offset
=
pp
.
second
;
auto
offset
=
pp
.
second
;
p
.
replace_instruction
(
m
.
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
s
)},
{
"offset"
,
offset
}}),
mem
);
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
s
)},
{
"offset"
,
offset
}}),
mem
);
}
}
}
}
...
...
src/eliminate_common_subexpression.cpp
View file @
faefeef9
...
@@ -11,7 +11,7 @@ namespace migraphx {
...
@@ -11,7 +11,7 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Range
>
template
<
class
Range
>
void
cse_range
(
module
&
p
,
Range
&&
r
)
void
cse_range
(
module
&
m
,
Range
&&
r
)
{
{
std
::
unordered_multimap
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_multimap
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_set
<
instruction_ref
>
processed_ins
;
std
::
unordered_set
<
instruction_ref
>
processed_ins
;
...
@@ -30,24 +30,24 @@ void cse_range(module& p, Range&& r)
...
@@ -30,24 +30,24 @@ void cse_range(module& p, Range&& r)
continue
;
continue
;
if
(
*
eq
!=
*
ins
)
if
(
*
eq
!=
*
ins
)
continue
;
continue
;
p
.
replace_instruction
(
ins
,
eq
);
m
.
replace_instruction
(
ins
,
eq
);
processed_ins
.
emplace
(
ins
);
processed_ins
.
emplace
(
ins
);
std
::
vector
<
instruction_ref
>
outputs
;
std
::
vector
<
instruction_ref
>
outputs
;
std
::
copy_if
(
eq
->
outputs
().
begin
(),
std
::
copy_if
(
eq
->
outputs
().
begin
(),
eq
->
outputs
().
end
(),
eq
->
outputs
().
end
(),
std
::
back_inserter
(
outputs
),
std
::
back_inserter
(
outputs
),
[
&
](
auto
x
)
{
return
p
.
has_instruction
(
x
);
});
[
&
](
auto
x
)
{
return
m
.
has_instruction
(
x
);
});
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
return
std
::
distance
(
eq
,
x
)
<
std
::
distance
(
eq
,
y
);
return
std
::
distance
(
eq
,
x
)
<
std
::
distance
(
eq
,
y
);
});
});
cse_range
(
p
,
outputs
);
cse_range
(
m
,
outputs
);
}
}
instructions
.
emplace
(
ins
->
name
(),
ins
);
instructions
.
emplace
(
ins
->
name
(),
ins
);
}
}
}
}
void
eliminate_common_subexpression
::
apply
(
module
&
p
)
const
{
cse_range
(
p
,
iterator_for
(
p
));
}
void
eliminate_common_subexpression
::
apply
(
module
&
m
)
const
{
cse_range
(
m
,
iterator_for
(
m
));
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/eliminate_concat.cpp
View file @
faefeef9
...
@@ -13,9 +13,9 @@
...
@@ -13,9 +13,9 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_concat
::
apply
(
module
&
p
)
const
void
eliminate_concat
::
apply
(
module
&
m
)
const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// Look for the concat operator
// Look for the concat operator
if
(
ins
->
name
()
!=
concat_opt
.
name
())
if
(
ins
->
name
()
!=
concat_opt
.
name
())
...
@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const
...
@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const
std
::
sort
(
sorted_allocations
.
begin
(),
std
::
sort
(
sorted_allocations
.
begin
(),
sorted_allocations
.
end
(),
sorted_allocations
.
end
(),
[
&
](
instruction_ref
x
,
instruction_ref
y
)
{
[
&
](
instruction_ref
x
,
instruction_ref
y
)
{
return
std
::
distance
(
p
.
begin
(),
x
)
<
std
::
distance
(
p
.
begin
(),
y
);
return
std
::
distance
(
m
.
begin
(),
x
)
<
std
::
distance
(
m
.
begin
(),
y
);
});
});
// Move "super" allocation to the front
// Move "super" allocation to the front
auto
first
=
sorted_allocations
.
front
();
auto
first
=
sorted_allocations
.
front
();
auto
super
=
p
.
move_instruction
(
last
,
first
);
auto
super
=
m
.
move_instruction
(
last
,
first
);
// Replace each allocation with a load
// Replace each allocation with a load
std
::
size_t
offset
=
0
;
std
::
size_t
offset
=
0
;
for
(
auto
alloc
:
allocations
)
for
(
auto
alloc
:
allocations
)
{
{
op
::
load
op
{
alloc
->
get_shape
(),
offset
};
op
::
load
op
{
alloc
->
get_shape
(),
offset
};
p
.
replace_instruction
(
alloc
,
op
,
{
super
});
m
.
replace_instruction
(
alloc
,
op
,
{
super
});
offset
+=
alloc
->
get_shape
().
bytes
();
offset
+=
alloc
->
get_shape
().
bytes
();
}
}
std
::
vector
<
instruction_ref
>
args
=
{
super
};
std
::
vector
<
instruction_ref
>
args
=
{
super
};
std
::
copy
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
()
-
1
,
std
::
back_inserter
(
args
));
std
::
copy
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
()
-
1
,
std
::
back_inserter
(
args
));
p
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"identity"
),
args
);
m
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"identity"
),
args
);
}
}
}
}
}
}
...
...
src/eliminate_contiguous.cpp
View file @
faefeef9
...
@@ -69,9 +69,9 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -69,9 +69,9 @@ static bool try_compute_shape(instruction_ref ins,
return
try_compute_shape
(
ins
,
inputs
,
mods
);
return
try_compute_shape
(
ins
,
inputs
,
mods
);
}
}
void
eliminate_contiguous
::
apply
(
module
&
p
)
const
void
eliminate_contiguous
::
apply
(
module
&
m
)
const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// return instruction should have inputs with standard shape
// return instruction should have inputs with standard shape
if
(
ins
->
name
()
==
"@return"
)
if
(
ins
->
name
()
==
"@return"
)
...
@@ -96,8 +96,8 @@ void eliminate_contiguous::apply(module& p) const
...
@@ -96,8 +96,8 @@ void eliminate_contiguous::apply(module& p) const
auto
c
=
op
::
contiguous
{};
auto
c
=
op
::
contiguous
{};
auto
r
=
c
.
compute
(
c
.
compute_shape
({
prev
->
get_shape
()}),
{
prev
->
eval
()});
auto
r
=
c
.
compute
(
c
.
compute_shape
({
prev
->
get_shape
()}),
{
prev
->
eval
()});
auto
l
=
p
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
auto
l
=
m
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
p
.
replace_instruction
(
arg
,
l
);
m
.
replace_instruction
(
arg
,
l
);
}
}
}
}
}
}
...
...
src/eliminate_identity.cpp
View file @
faefeef9
...
@@ -8,21 +8,21 @@
...
@@ -8,21 +8,21 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_identity
::
apply
(
module
&
p
)
const
void
eliminate_identity
::
apply
(
module
&
m
)
const
{
{
auto
last
=
std
::
prev
(
p
.
end
());
auto
last
=
std
::
prev
(
m
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// Skip the first instruction, since we always process the previous
// Skip the first instruction, since we always process the previous
// instruction
// instruction
if
(
ins
==
p
.
begin
())
if
(
ins
==
m
.
begin
())
continue
;
continue
;
const
auto
i
=
std
::
prev
(
ins
);
const
auto
i
=
std
::
prev
(
ins
);
if
(
i
->
name
()
==
"identity"
)
if
(
i
->
name
()
==
"identity"
)
{
{
p
.
replace_instruction
(
i
,
i
->
inputs
().
front
());
m
.
replace_instruction
(
i
,
i
->
inputs
().
front
());
p
.
move_instruction
(
i
,
p
.
end
());
m
.
move_instruction
(
i
,
m
.
end
());
}
}
if
(
ins
==
last
)
if
(
ins
==
last
)
{
{
...
@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const
...
@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const
const
instruction_ref
&
identity_input
=
ins
->
inputs
().
front
();
const
instruction_ref
&
identity_input
=
ins
->
inputs
().
front
();
if
(
identity_input
->
outputs
().
size
()
==
1
)
if
(
identity_input
->
outputs
().
size
()
==
1
)
{
{
p
.
move_instruction
(
identity_input
,
i
);
m
.
move_instruction
(
identity_input
,
i
);
// since this is the last instruction, removing it only
// since this is the last instruction, removing it only
// requires changing "last" and calling remove below
// requires changing "last" and calling remove below
last
=
std
::
prev
(
last
);
last
=
std
::
prev
(
last
);
...
@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const
...
@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const
break
;
break
;
}
}
}
}
p
.
remove_instructions
(
std
::
next
(
last
),
p
.
end
());
m
.
remove_instructions
(
std
::
next
(
last
),
m
.
end
());
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/adjust_allocation.hpp
View file @
faefeef9
...
@@ -13,7 +13,7 @@ struct adjust_allocation
...
@@ -13,7 +13,7 @@ struct adjust_allocation
{
{
allocation_model
model
;
allocation_model
model
;
std
::
string
name
()
const
{
return
"adjust_allocation"
;
}
std
::
string
name
()
const
{
return
"adjust_allocation"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/analyze_streams.hpp
View file @
faefeef9
...
@@ -16,7 +16,7 @@ struct stream_race
...
@@ -16,7 +16,7 @@ struct stream_race
instruction_ref
before
;
instruction_ref
before
;
};
};
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
p
,
const
stream_model
&
m
);
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
m
,
const
stream_model
&
strm
m
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/auto_contiguous.hpp
View file @
faefeef9
...
@@ -13,7 +13,7 @@ struct module;
...
@@ -13,7 +13,7 @@ struct module;
struct
auto_contiguous
struct
auto_contiguous
{
{
std
::
string
name
()
const
{
return
"auto_contiguous"
;
}
std
::
string
name
()
const
{
return
"auto_contiguous"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/check_context.hpp
View file @
faefeef9
...
@@ -33,7 +33,7 @@ struct check_context
...
@@ -33,7 +33,7 @@ struct check_context
};
};
std
::
string
name
()
const
{
return
"check_context"
;
}
std
::
string
name
()
const
{
return
"check_context"
;
}
void
apply
(
module
&
p
)
const
{
p
.
insert_instruction
(
p
.
begin
(),
op
{});
}
void
apply
(
module
&
m
)
const
{
m
.
insert_instruction
(
m
.
begin
(),
op
{});
}
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment