Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
faefeef9
Unverified
Commit
faefeef9
authored
May 25, 2022
by
Charlie Lin
Committed by
GitHub
May 25, 2022
Browse files
Merge branch 'develop' into dyn_shape_update
parents
97a40ac3
bf0a4713
Changes
94
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
174 additions
and
54 deletions
+174
-54
src/include/migraphx/eliminate_allocation.hpp
src/include/migraphx/eliminate_allocation.hpp
+1
-1
src/include/migraphx/eliminate_common_subexpression.hpp
src/include/migraphx/eliminate_common_subexpression.hpp
+1
-1
src/include/migraphx/eliminate_concat.hpp
src/include/migraphx/eliminate_concat.hpp
+1
-1
src/include/migraphx/eliminate_contiguous.hpp
src/include/migraphx/eliminate_contiguous.hpp
+1
-1
src/include/migraphx/eliminate_identity.hpp
src/include/migraphx/eliminate_identity.hpp
+1
-1
src/include/migraphx/make_op.hpp
src/include/migraphx/make_op.hpp
+13
-1
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+68
-4
src/include/migraphx/memory_coloring.hpp
src/include/migraphx/memory_coloring.hpp
+1
-1
src/include/migraphx/propagate_constant.hpp
src/include/migraphx/propagate_constant.hpp
+1
-1
src/include/migraphx/rewrite_batchnorm.hpp
src/include/migraphx/rewrite_batchnorm.hpp
+1
-1
src/include/migraphx/rewrite_pooling.hpp
src/include/migraphx/rewrite_pooling.hpp
+1
-1
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+12
-13
src/include/migraphx/schedule.hpp
src/include/migraphx/schedule.hpp
+1
-1
src/include/migraphx/simplify_algebra.hpp
src/include/migraphx/simplify_algebra.hpp
+1
-1
src/include/migraphx/simplify_reshapes.hpp
src/include/migraphx/simplify_reshapes.hpp
+1
-1
src/make_op.cpp
src/make_op.cpp
+28
-7
src/onnx/parse_mean.cpp
src/onnx/parse_mean.cpp
+27
-11
src/opt/memory_coloring.cpp
src/opt/memory_coloring.cpp
+2
-2
src/propagate_constant.cpp
src/propagate_constant.cpp
+4
-4
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+8
-0
No files found.
src/include/migraphx/eliminate_allocation.hpp
View file @
faefeef9
...
...
@@ -19,7 +19,7 @@ struct eliminate_allocation
std
::
string
allocation_op
{};
std
::
size_t
alignment
=
32
;
std
::
string
name
()
const
{
return
"eliminate_allocation"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_common_subexpression.hpp
View file @
faefeef9
...
...
@@ -16,7 +16,7 @@ struct module;
struct
eliminate_common_subexpression
{
std
::
string
name
()
const
{
return
"eliminate_common_subexpression"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_concat.hpp
View file @
faefeef9
...
...
@@ -18,7 +18,7 @@ struct eliminate_concat
{
concat_optimization
concat_opt
;
std
::
string
name
()
const
{
return
"eliminate_concat"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_contiguous.hpp
View file @
faefeef9
...
...
@@ -17,7 +17,7 @@ struct eliminate_contiguous
{
std
::
string
op_name
;
std
::
string
name
()
const
{
return
"eliminate_contiguous"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_identity.hpp
View file @
faefeef9
...
...
@@ -18,7 +18,7 @@ struct module;
struct
eliminate_identity
{
std
::
string
name
()
const
{
return
"eliminate_identity"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/make_op.hpp
View file @
faefeef9
...
...
@@ -9,7 +9,19 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
operation
make_op
(
const
std
::
string
&
name
);
operation
make_op
(
const
std
::
string
&
name
,
const
value
&
v
);
operation
make_op
(
const
std
::
string
&
name
,
const
std
::
initializer_list
<
std
::
pair
<
std
::
string
,
value
>>&
v
);
operation
make_op_from_value
(
const
std
::
string
&
name
,
const
value
&
v
);
// A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all
// initializer_list are key-value pairs, whereas migraphx::value allows other
// types of initializer_list such as for arrays.
template
<
class
Value
>
operation
make_op
(
const
std
::
string
&
name
,
const
Value
&
v
)
{
return
make_op_from_value
(
name
,
v
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/matcher.hpp
View file @
faefeef9
...
...
@@ -156,6 +156,19 @@ struct id_matcher
}
};
// Forward declare class and constructors
template
<
class
M
>
struct
basic_matcher
;
template
<
class
M
>
basic_matcher
<
M
>
make_basic_matcher
(
M
m
);
template
<
class
F
>
basic_matcher
<
function_matcher
<
F
>>
make_basic_fun_matcher
(
F
f
);
template
<
class
P
>
basic_matcher
<
predicate_matcher
<
P
>>
make_basic_pred_matcher
(
P
p
);
/// The basic matcher provides the all_of composability of the matcher
template
<
class
M
>
struct
basic_matcher
...
...
@@ -167,8 +180,8 @@ struct basic_matcher
{
// Copy m because we cant capture `this` by value
auto
mm
=
m
;
return
make_b
f
_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
return
make_b
asic_fun
_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
auto
result
=
mm
.
match
(
ctx
,
ins
);
if
(
result
)
{
...
...
@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
struct
matcher_result
{
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
struct
instruction_container
{
instruction_container
()
=
default
;
instruction_container
(
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
x
)
:
ins_map
(
std
::
move
(
x
))
{
}
instruction_ref
operator
[](
const
std
::
string
&
name
)
const
{
auto
it
=
ins_map
.
find
(
name
);
if
(
it
==
ins_map
.
end
())
MIGRAPHX_THROW
(
"Accessing name that wasn't bound in matcher: "
+
name
);
return
it
->
second
;
}
auto
find
(
const
std
::
string
&
name
)
const
{
return
ins_map
.
find
(
name
);
}
auto
begin
()
const
{
return
ins_map
.
cbegin
();
}
auto
end
()
const
{
return
ins_map
.
cend
();
}
bool
has_instructions_in
(
const
module
&
mod
)
const
{
return
std
::
all_of
(
ins_map
.
begin
(),
ins_map
.
end
(),
[
&
](
auto
&&
p
)
{
return
mod
.
has_instruction
(
p
.
second
);
});
}
private:
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
ins_map
;
};
instruction_container
instructions
;
instruction_ref
result
;
};
...
...
@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{
result
.
result
=
ins
;
result
.
instructions
=
ctx
.
instructions
;
assert
(
result
.
instructions
.
has_instructions_in
(
mod
));
}
else
{
...
...
@@ -533,6 +579,18 @@ auto skip_output(Ms... ms)
});
}
inline
auto
var
(
std
::
string
s
)
{
return
make_basic_fun_matcher
(
[
=
,
s
=
std
::
move
(
s
)](
const
matcher_context
&
ctx
,
instruction_ref
)
->
optional
<
instruction_ref
>
{
auto
it
=
ctx
.
instructions
.
find
(
s
);
if
(
it
==
ctx
.
instructions
.
end
())
return
nullopt
;
return
it
->
second
;
});
}
inline
auto
name
(
std
::
string
s
)
{
return
make_basic_pred_matcher
(
...
...
@@ -696,10 +754,16 @@ auto skip_broadcasts(Ms... ms)
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
))(
ms
...);
}
template
<
class
...
Ms
>
auto
skip_broadcasts_converts
(
Ms
...
ms
)
{
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
,
"convert"
))(
ms
...);
}
template
<
class
T
>
inline
auto
has_value
(
T
x
,
float
tolerance
=
1e-6
)
{
return
skip_broadcasts
(
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
skip_broadcasts
_converts
(
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"@literal"
)
return
false
;
auto
l
=
ins
->
get_literal
();
...
...
src/include/migraphx/memory_coloring.hpp
View file @
faefeef9
...
...
@@ -17,7 +17,7 @@ struct memory_coloring
std
::
string
allocation_op
{};
bool
verify
=
false
;
std
::
string
name
()
const
{
return
"memory coloring"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/propagate_constant.hpp
View file @
faefeef9
...
...
@@ -15,7 +15,7 @@ struct module;
struct
propagate_constant
{
std
::
string
name
()
const
{
return
"propagate_constant"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/rewrite_batchnorm.hpp
View file @
faefeef9
...
...
@@ -16,7 +16,7 @@ struct module;
struct
rewrite_batchnorm
{
std
::
string
name
()
const
{
return
"rewrite_batchnorm"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/rewrite_pooling.hpp
View file @
faefeef9
...
...
@@ -15,7 +15,7 @@ struct module;
struct
rewrite_pooling
{
std
::
string
name
()
const
{
return
"rewrite_pooling"
;
}
void
apply
(
module
&
prog
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/rewrite_rnn.hpp
View file @
faefeef9
...
...
@@ -19,22 +19,22 @@ struct module;
struct
rewrite_rnn
{
std
::
string
name
()
const
{
return
"rewrite_rnn"
;
}
void
apply
(
module
&
prog
)
const
;
void
apply
(
module
&
m
)
const
;
private:
// for vanilla rnn operators
void
apply_vanilla_rnn
(
module
&
prog
,
instruction_ref
ins
)
const
;
void
apply_vanilla_rnn
(
module
&
m
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
vanilla_rnn_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
operation
&
actv_func
)
const
;
std
::
vector
<
operation
>
vanilla_rnn_actv_funcs
(
instruction_ref
ins
)
const
;
// for gru operators
void
apply_gru
(
module
&
prog
,
instruction_ref
ins
)
const
;
void
apply_gru
(
module
&
m
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
gru_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
int
linear_before_reset
,
...
...
@@ -44,9 +44,9 @@ struct rewrite_rnn
std
::
vector
<
operation
>
gru_actv_funcs
(
instruction_ref
ins
)
const
;
// for lstm operators
void
apply_lstm
(
module
&
prog
,
instruction_ref
ins
)
const
;
void
apply_lstm
(
module
&
m
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
lstm_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
const
operation
&
actv_func1
,
...
...
@@ -55,24 +55,23 @@ struct rewrite_rnn
std
::
vector
<
operation
>
lstm_actv_funcs
(
instruction_ref
ins
)
const
;
bool
is_variable_seq_lens
(
const
module
&
prog
,
instruction_ref
seq_lens
)
const
;
instruction_ref
replace_last_hs_output
(
module
&
prog
,
bool
is_variable_seq_lens
(
const
module
&
m
,
instruction_ref
seq_lens
)
const
;
instruction_ref
replace_last_hs_output
(
module
&
m
,
instruction_ref
ins
,
instruction_ref
seq_lens
,
instruction_ref
last_hs_output
,
op
::
rnn_direction
dirct
)
const
;
void
replace_last_cell_output
(
module
&
prog
,
void
replace_last_cell_output
(
module
&
m
,
instruction_ref
ins
,
instruction_ref
seq_lens
,
instruction_ref
cell_outputs
,
instruction_ref
last_cell_output
,
op
::
rnn_direction
dirct
)
const
;
std
::
size_t
get_seq_len
(
const
module
&
prog
,
instruction_ref
input
,
instruction_ref
seq_lens
)
const
;
std
::
size_t
get_seq_len
(
const
module
&
m
,
instruction_ref
input
,
instruction_ref
seq_lens
)
const
;
instruction_ref
pad_hidden_states
(
module
&
prog
,
instruction_ref
pad_hidden_states
(
module
&
m
,
instruction_ref
seq
,
instruction_ref
seq_lens
,
instruction_ref
hs
)
const
;
...
...
src/include/migraphx/schedule.hpp
View file @
faefeef9
...
...
@@ -19,7 +19,7 @@ struct schedule
schedule_model
model
{};
bool
enable
=
true
;
std
::
string
name
()
const
{
return
"schedule"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/simplify_algebra.hpp
View file @
faefeef9
...
...
@@ -15,7 +15,7 @@ struct module;
struct
simplify_algebra
{
std
::
string
name
()
const
{
return
"simplify_algebra"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/simplify_reshapes.hpp
View file @
faefeef9
...
...
@@ -16,7 +16,7 @@ struct module;
struct
simplify_reshapes
{
std
::
string
name
()
const
{
return
"simplify_reshapes"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/make_op.cpp
100755 → 100644
View file @
faefeef9
...
...
@@ -5,20 +5,41 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
operation
make_op
(
const
std
::
string
&
name
)
{
return
load_op
(
name
);
}
operation
make_op
(
const
std
::
string
&
name
,
const
value
&
v
)
template
<
class
F
>
operation
make_op_generic
(
const
std
::
string
&
name
,
F
for_each
)
{
if
(
not
(
v
.
is_object
()
or
(
v
.
empty
()
and
v
.
is_array
())))
MIGRAPHX_THROW
(
"Value is not an object"
);
auto
op
=
load_op
(
name
);
// Merge values
value
w
=
op
.
to_value
();
for
(
auto
&&
x
:
v
)
{
w
.
at
(
x
.
get_key
())
=
x
.
without_key
();
}
for_each
([
&
](
const
auto
&
key
,
const
auto
&
x
)
{
if
(
not
w
.
contains
(
key
))
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
MIGRAPHX_THROW
(
"No key '"
+
key
+
"' in "
+
name
);
w
.
at
(
key
)
=
x
;
});
op
.
from_value
(
w
);
return
op
;
}
operation
make_op
(
const
std
::
string
&
name
,
const
std
::
initializer_list
<
std
::
pair
<
std
::
string
,
value
>>&
v
)
{
return
make_op_generic
(
name
,
[
&
](
auto
f
)
{
for
(
auto
&&
[
key
,
x
]
:
v
)
f
(
key
,
x
);
});
}
operation
make_op_from_value
(
const
std
::
string
&
name
,
const
value
&
v
)
{
if
(
not
(
v
.
is_object
()
or
(
v
.
empty
()
and
v
.
is_array
())))
MIGRAPHX_THROW
(
"Value is not an object for make_op: "
+
name
);
return
make_op_generic
(
name
,
[
&
](
auto
f
)
{
for
(
auto
&&
x
:
v
)
f
(
x
.
get_key
(),
x
.
without_key
());
});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/parse_mean.cpp
View file @
faefeef9
...
...
@@ -2,6 +2,7 @@
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -9,6 +10,9 @@ namespace onnx {
struct
parse_mean
:
op_parser
<
parse_mean
>
{
const
std
::
set
<
shape
::
type_t
>
float_types
=
{
shape
::
float_type
,
shape
::
half_type
,
shape
::
double_type
};
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Mean"
}};
}
/// Calculates the element-wise mean of n>=1 input tensors
...
...
@@ -24,17 +28,29 @@ struct parse_mean : op_parser<parse_mean>
auto
divisor
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
args
[
0
]
->
get_shape
().
type
()},
{
num_data
}});
// TODO: Only divide when using floating-point
return
std
::
accumulate
(
args
.
begin
()
+
1
,
args
.
end
(),
info
.
add_broadcastable_binary_op
(
"div"
,
args
[
0
],
divisor
),
[
&
](
auto
mean
,
auto
data_i
)
{
// Pre-divide each tensor element-wise by n to reduce risk of
// overflow during summation
auto
div
=
info
.
add_broadcastable_binary_op
(
"div"
,
data_i
,
divisor
);
return
info
.
add_broadcastable_binary_op
(
"add"
,
mean
,
div
);
});
if
(
contains
(
float_types
,
args
[
0
]
->
get_shape
().
type
()))
{
return
std
::
accumulate
(
args
.
begin
()
+
1
,
args
.
end
(),
info
.
add_broadcastable_binary_op
(
"div"
,
args
[
0
],
divisor
),
[
&
](
auto
mean
,
auto
data_i
)
{
// Pre-divide each tensor element-wise by n to reduce risk of
// overflow during summation
auto
div
=
info
.
add_broadcastable_binary_op
(
"div"
,
data_i
,
divisor
);
return
info
.
add_broadcastable_binary_op
(
"add"
,
mean
,
div
);
});
}
else
{
// Compute sum before division for integral types
auto
sum
=
std
::
accumulate
(
args
.
begin
()
+
1
,
args
.
end
(),
args
[
0
],
[
&
](
auto
accum
,
auto
data_i
)
{
return
info
.
add_broadcastable_binary_op
(
"add"
,
accum
,
data_i
);
});
return
info
.
add_broadcastable_binary_op
(
"div"
,
sum
,
divisor
);
}
}
};
...
...
src/opt/memory_coloring.cpp
View file @
faefeef9
...
...
@@ -4,11 +4,11 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
memory_coloring
::
apply
(
module
&
p
)
const
void
memory_coloring
::
apply
(
module
&
m
)
const
{
if
(
!
enabled
(
MIGRAPHX_DISABLE_MEMORY_COLORING
{}))
{
memory_coloring_impl
opt
(
&
p
,
allocation_op
,
verify
);
memory_coloring_impl
opt
(
&
m
,
allocation_op
,
verify
);
opt
.
run
();
}
}
...
...
src/propagate_constant.cpp
View file @
faefeef9
...
...
@@ -20,9 +20,9 @@ bool skip_propogate(instruction_ref ins)
return
false
;
}
void
propagate_constant
::
apply
(
module
&
p
)
const
void
propagate_constant
::
apply
(
module
&
m
)
const
{
for
(
auto
i
:
iterator_for
(
p
))
for
(
auto
i
:
iterator_for
(
m
))
{
if
(
i
->
name
()
!=
"@literal"
)
continue
;
...
...
@@ -42,8 +42,8 @@ void propagate_constant::apply(module& p) const
if
(
not
r
.
empty
())
{
assert
(
r
.
get_shape
()
==
child
->
get_shape
());
auto
l
=
p
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
self
(
p
.
replace_instruction
(
child
,
l
));
auto
l
=
m
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
self
(
m
.
replace_instruction
(
child
,
l
));
}
}
})(
i
);
...
...
src/py/migraphx_py.cpp
View file @
faefeef9
...
...
@@ -273,6 +273,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py
::
arg
(
"op"
),
py
::
arg
(
"args"
),
py
::
arg
(
"mod_args"
)
=
std
::
vector
<
migraphx
::
module
*>
{})
.
def
(
"add_literal"
,
[](
migraphx
::
module
&
mm
,
py
::
buffer
data
)
{
py
::
buffer_info
info
=
data
.
request
();
auto
literal_shape
=
to_shape
(
info
);
return
mm
.
add_literal
(
literal_shape
,
reinterpret_cast
<
char
*>
(
info
.
ptr
));
},
py
::
arg
(
"data"
))
.
def
(
"add_parameter"
,
[](
migraphx
::
module
&
mm
,
const
std
::
string
&
name
,
const
migraphx
::
shape
shape
)
{
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment