Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
417d6644
Commit
417d6644
authored
May 20, 2022
by
charlie
Browse files
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_conv
parents
79e27dac
4a312201
Changes
76
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
156 additions
and
52 deletions
+156
-52
src/include/migraphx/eliminate_allocation.hpp
src/include/migraphx/eliminate_allocation.hpp
+1
-1
src/include/migraphx/eliminate_common_subexpression.hpp
src/include/migraphx/eliminate_common_subexpression.hpp
+1
-1
src/include/migraphx/eliminate_concat.hpp
src/include/migraphx/eliminate_concat.hpp
+1
-1
src/include/migraphx/eliminate_contiguous.hpp
src/include/migraphx/eliminate_contiguous.hpp
+1
-1
src/include/migraphx/eliminate_identity.hpp
src/include/migraphx/eliminate_identity.hpp
+1
-1
src/include/migraphx/make_op.hpp
src/include/migraphx/make_op.hpp
+13
-1
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+68
-4
src/include/migraphx/memory_coloring.hpp
src/include/migraphx/memory_coloring.hpp
+1
-1
src/include/migraphx/propagate_constant.hpp
src/include/migraphx/propagate_constant.hpp
+1
-1
src/include/migraphx/rewrite_batchnorm.hpp
src/include/migraphx/rewrite_batchnorm.hpp
+1
-1
src/include/migraphx/rewrite_pooling.hpp
src/include/migraphx/rewrite_pooling.hpp
+1
-1
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+12
-13
src/include/migraphx/schedule.hpp
src/include/migraphx/schedule.hpp
+1
-1
src/include/migraphx/simplify_algebra.hpp
src/include/migraphx/simplify_algebra.hpp
+1
-1
src/include/migraphx/simplify_reshapes.hpp
src/include/migraphx/simplify_reshapes.hpp
+1
-1
src/make_op.cpp
src/make_op.cpp
+28
-7
src/opt/memory_coloring.cpp
src/opt/memory_coloring.cpp
+2
-2
src/propagate_constant.cpp
src/propagate_constant.cpp
+4
-4
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+8
-0
src/rewrite_batchnorm.cpp
src/rewrite_batchnorm.cpp
+9
-9
No files found.
src/include/migraphx/eliminate_allocation.hpp
View file @
417d6644
...
@@ -19,7 +19,7 @@ struct eliminate_allocation
...
@@ -19,7 +19,7 @@ struct eliminate_allocation
std
::
string
allocation_op
{};
std
::
string
allocation_op
{};
std
::
size_t
alignment
=
32
;
std
::
size_t
alignment
=
32
;
std
::
string
name
()
const
{
return
"eliminate_allocation"
;
}
std
::
string
name
()
const
{
return
"eliminate_allocation"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_common_subexpression.hpp
View file @
417d6644
...
@@ -16,7 +16,7 @@ struct module;
...
@@ -16,7 +16,7 @@ struct module;
struct
eliminate_common_subexpression
struct
eliminate_common_subexpression
{
{
std
::
string
name
()
const
{
return
"eliminate_common_subexpression"
;
}
std
::
string
name
()
const
{
return
"eliminate_common_subexpression"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_concat.hpp
View file @
417d6644
...
@@ -18,7 +18,7 @@ struct eliminate_concat
...
@@ -18,7 +18,7 @@ struct eliminate_concat
{
{
concat_optimization
concat_opt
;
concat_optimization
concat_opt
;
std
::
string
name
()
const
{
return
"eliminate_concat"
;
}
std
::
string
name
()
const
{
return
"eliminate_concat"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_contiguous.hpp
View file @
417d6644
...
@@ -17,7 +17,7 @@ struct eliminate_contiguous
...
@@ -17,7 +17,7 @@ struct eliminate_contiguous
{
{
std
::
string
op_name
;
std
::
string
op_name
;
std
::
string
name
()
const
{
return
"eliminate_contiguous"
;
}
std
::
string
name
()
const
{
return
"eliminate_contiguous"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_identity.hpp
View file @
417d6644
...
@@ -18,7 +18,7 @@ struct module;
...
@@ -18,7 +18,7 @@ struct module;
struct
eliminate_identity
struct
eliminate_identity
{
{
std
::
string
name
()
const
{
return
"eliminate_identity"
;
}
std
::
string
name
()
const
{
return
"eliminate_identity"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/make_op.hpp
View file @
417d6644
...
@@ -9,7 +9,19 @@ namespace migraphx {
...
@@ -9,7 +9,19 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
operation
make_op
(
const
std
::
string
&
name
);
operation
make_op
(
const
std
::
string
&
name
);
operation
make_op
(
const
std
::
string
&
name
,
const
value
&
v
);
operation
make_op
(
const
std
::
string
&
name
,
const
std
::
initializer_list
<
std
::
pair
<
std
::
string
,
value
>>&
v
);
operation
make_op_from_value
(
const
std
::
string
&
name
,
const
value
&
v
);
// A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all
// initializer_list are key-value pairs, whereas migraphx::value allows other
// types of initializer_list such as for arrays.
template
<
class
Value
>
operation
make_op
(
const
std
::
string
&
name
,
const
Value
&
v
)
{
return
make_op_from_value
(
name
,
v
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/matcher.hpp
View file @
417d6644
...
@@ -156,6 +156,19 @@ struct id_matcher
...
@@ -156,6 +156,19 @@ struct id_matcher
}
}
};
};
// Forward declare class and constructors
template
<
class
M
>
struct
basic_matcher
;
template
<
class
M
>
basic_matcher
<
M
>
make_basic_matcher
(
M
m
);
template
<
class
F
>
basic_matcher
<
function_matcher
<
F
>>
make_basic_fun_matcher
(
F
f
);
template
<
class
P
>
basic_matcher
<
predicate_matcher
<
P
>>
make_basic_pred_matcher
(
P
p
);
/// The basic matcher provides the all_of composability of the matcher
/// The basic matcher provides the all_of composability of the matcher
template
<
class
M
>
template
<
class
M
>
struct
basic_matcher
struct
basic_matcher
...
@@ -167,8 +180,8 @@ struct basic_matcher
...
@@ -167,8 +180,8 @@ struct basic_matcher
{
{
// Copy m because we cant capture `this` by value
// Copy m because we cant capture `this` by value
auto
mm
=
m
;
auto
mm
=
m
;
return
make_b
f
_matcher
([
=
](
matcher_context
&
ctx
,
return
make_b
asic_fun
_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
auto
result
=
mm
.
match
(
ctx
,
ins
);
auto
result
=
mm
.
match
(
ctx
,
ins
);
if
(
result
)
if
(
result
)
{
{
...
@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
...
@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
struct
matcher_result
struct
matcher_result
{
{
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
struct
instruction_container
{
instruction_container
()
=
default
;
instruction_container
(
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
x
)
:
ins_map
(
std
::
move
(
x
))
{
}
instruction_ref
operator
[](
const
std
::
string
&
name
)
const
{
auto
it
=
ins_map
.
find
(
name
);
if
(
it
==
ins_map
.
end
())
MIGRAPHX_THROW
(
"Accessing name that wasn't bound in matcher: "
+
name
);
return
it
->
second
;
}
auto
find
(
const
std
::
string
&
name
)
const
{
return
ins_map
.
find
(
name
);
}
auto
begin
()
const
{
return
ins_map
.
cbegin
();
}
auto
end
()
const
{
return
ins_map
.
cend
();
}
bool
has_instructions_in
(
const
module
&
mod
)
const
{
return
std
::
all_of
(
ins_map
.
begin
(),
ins_map
.
end
(),
[
&
](
auto
&&
p
)
{
return
mod
.
has_instruction
(
p
.
second
);
});
}
private:
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
ins_map
;
};
instruction_container
instructions
;
instruction_ref
result
;
instruction_ref
result
;
};
};
...
@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
...
@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{
{
result
.
result
=
ins
;
result
.
result
=
ins
;
result
.
instructions
=
ctx
.
instructions
;
result
.
instructions
=
ctx
.
instructions
;
assert
(
result
.
instructions
.
has_instructions_in
(
mod
));
}
}
else
else
{
{
...
@@ -533,6 +579,18 @@ auto skip_output(Ms... ms)
...
@@ -533,6 +579,18 @@ auto skip_output(Ms... ms)
});
});
}
}
inline
auto
var
(
std
::
string
s
)
{
return
make_basic_fun_matcher
(
[
=
,
s
=
std
::
move
(
s
)](
const
matcher_context
&
ctx
,
instruction_ref
)
->
optional
<
instruction_ref
>
{
auto
it
=
ctx
.
instructions
.
find
(
s
);
if
(
it
==
ctx
.
instructions
.
end
())
return
nullopt
;
return
it
->
second
;
});
}
inline
auto
name
(
std
::
string
s
)
inline
auto
name
(
std
::
string
s
)
{
{
return
make_basic_pred_matcher
(
return
make_basic_pred_matcher
(
...
@@ -696,10 +754,16 @@ auto skip_broadcasts(Ms... ms)
...
@@ -696,10 +754,16 @@ auto skip_broadcasts(Ms... ms)
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
))(
ms
...);
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
))(
ms
...);
}
}
template
<
class
...
Ms
>
auto
skip_broadcasts_converts
(
Ms
...
ms
)
{
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
,
"convert"
))(
ms
...);
}
template
<
class
T
>
template
<
class
T
>
inline
auto
has_value
(
T
x
,
float
tolerance
=
1e-6
)
inline
auto
has_value
(
T
x
,
float
tolerance
=
1e-6
)
{
{
return
skip_broadcasts
(
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
skip_broadcasts
_converts
(
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"@literal"
)
if
(
ins
->
name
()
!=
"@literal"
)
return
false
;
return
false
;
auto
l
=
ins
->
get_literal
();
auto
l
=
ins
->
get_literal
();
...
...
src/include/migraphx/memory_coloring.hpp
View file @
417d6644
...
@@ -17,7 +17,7 @@ struct memory_coloring
...
@@ -17,7 +17,7 @@ struct memory_coloring
std
::
string
allocation_op
{};
std
::
string
allocation_op
{};
bool
verify
=
false
;
bool
verify
=
false
;
std
::
string
name
()
const
{
return
"memory coloring"
;
}
std
::
string
name
()
const
{
return
"memory coloring"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/propagate_constant.hpp
View file @
417d6644
...
@@ -15,7 +15,7 @@ struct module;
...
@@ -15,7 +15,7 @@ struct module;
struct
propagate_constant
struct
propagate_constant
{
{
std
::
string
name
()
const
{
return
"propagate_constant"
;
}
std
::
string
name
()
const
{
return
"propagate_constant"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/rewrite_batchnorm.hpp
View file @
417d6644
...
@@ -16,7 +16,7 @@ struct module;
...
@@ -16,7 +16,7 @@ struct module;
struct
rewrite_batchnorm
struct
rewrite_batchnorm
{
{
std
::
string
name
()
const
{
return
"rewrite_batchnorm"
;
}
std
::
string
name
()
const
{
return
"rewrite_batchnorm"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/rewrite_pooling.hpp
View file @
417d6644
...
@@ -15,7 +15,7 @@ struct module;
...
@@ -15,7 +15,7 @@ struct module;
struct
rewrite_pooling
struct
rewrite_pooling
{
{
std
::
string
name
()
const
{
return
"rewrite_pooling"
;
}
std
::
string
name
()
const
{
return
"rewrite_pooling"
;
}
void
apply
(
module
&
prog
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/rewrite_rnn.hpp
View file @
417d6644
...
@@ -19,22 +19,22 @@ struct module;
...
@@ -19,22 +19,22 @@ struct module;
struct
rewrite_rnn
struct
rewrite_rnn
{
{
std
::
string
name
()
const
{
return
"rewrite_rnn"
;
}
std
::
string
name
()
const
{
return
"rewrite_rnn"
;
}
void
apply
(
module
&
prog
)
const
;
void
apply
(
module
&
m
)
const
;
private:
private:
// for vanilla rnn operators
// for vanilla rnn operators
void
apply_vanilla_rnn
(
module
&
prog
,
instruction_ref
ins
)
const
;
void
apply_vanilla_rnn
(
module
&
m
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
vanilla_rnn_cell
(
bool
is_forward
,
std
::
vector
<
instruction_ref
>
vanilla_rnn_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
std
::
vector
<
instruction_ref
>
inputs
,
operation
&
actv_func
)
const
;
operation
&
actv_func
)
const
;
std
::
vector
<
operation
>
vanilla_rnn_actv_funcs
(
instruction_ref
ins
)
const
;
std
::
vector
<
operation
>
vanilla_rnn_actv_funcs
(
instruction_ref
ins
)
const
;
// for gru operators
// for gru operators
void
apply_gru
(
module
&
prog
,
instruction_ref
ins
)
const
;
void
apply_gru
(
module
&
m
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
gru_cell
(
bool
is_forward
,
std
::
vector
<
instruction_ref
>
gru_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
std
::
vector
<
instruction_ref
>
inputs
,
int
linear_before_reset
,
int
linear_before_reset
,
...
@@ -44,9 +44,9 @@ struct rewrite_rnn
...
@@ -44,9 +44,9 @@ struct rewrite_rnn
std
::
vector
<
operation
>
gru_actv_funcs
(
instruction_ref
ins
)
const
;
std
::
vector
<
operation
>
gru_actv_funcs
(
instruction_ref
ins
)
const
;
// for lstm operators
// for lstm operators
void
apply_lstm
(
module
&
prog
,
instruction_ref
ins
)
const
;
void
apply_lstm
(
module
&
m
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
lstm_cell
(
bool
is_forward
,
std
::
vector
<
instruction_ref
>
lstm_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
std
::
vector
<
instruction_ref
>
inputs
,
const
operation
&
actv_func1
,
const
operation
&
actv_func1
,
...
@@ -55,24 +55,23 @@ struct rewrite_rnn
...
@@ -55,24 +55,23 @@ struct rewrite_rnn
std
::
vector
<
operation
>
lstm_actv_funcs
(
instruction_ref
ins
)
const
;
std
::
vector
<
operation
>
lstm_actv_funcs
(
instruction_ref
ins
)
const
;
bool
is_variable_seq_lens
(
const
module
&
prog
,
instruction_ref
seq_lens
)
const
;
bool
is_variable_seq_lens
(
const
module
&
m
,
instruction_ref
seq_lens
)
const
;
instruction_ref
replace_last_hs_output
(
module
&
prog
,
instruction_ref
replace_last_hs_output
(
module
&
m
,
instruction_ref
ins
,
instruction_ref
ins
,
instruction_ref
seq_lens
,
instruction_ref
seq_lens
,
instruction_ref
last_hs_output
,
instruction_ref
last_hs_output
,
op
::
rnn_direction
dirct
)
const
;
op
::
rnn_direction
dirct
)
const
;
void
replace_last_cell_output
(
module
&
prog
,
void
replace_last_cell_output
(
module
&
m
,
instruction_ref
ins
,
instruction_ref
ins
,
instruction_ref
seq_lens
,
instruction_ref
seq_lens
,
instruction_ref
cell_outputs
,
instruction_ref
cell_outputs
,
instruction_ref
last_cell_output
,
instruction_ref
last_cell_output
,
op
::
rnn_direction
dirct
)
const
;
op
::
rnn_direction
dirct
)
const
;
std
::
size_t
std
::
size_t
get_seq_len
(
const
module
&
m
,
instruction_ref
input
,
instruction_ref
seq_lens
)
const
;
get_seq_len
(
const
module
&
prog
,
instruction_ref
input
,
instruction_ref
seq_lens
)
const
;
instruction_ref
pad_hidden_states
(
module
&
prog
,
instruction_ref
pad_hidden_states
(
module
&
m
,
instruction_ref
seq
,
instruction_ref
seq
,
instruction_ref
seq_lens
,
instruction_ref
seq_lens
,
instruction_ref
hs
)
const
;
instruction_ref
hs
)
const
;
...
...
src/include/migraphx/schedule.hpp
View file @
417d6644
...
@@ -19,7 +19,7 @@ struct schedule
...
@@ -19,7 +19,7 @@ struct schedule
schedule_model
model
{};
schedule_model
model
{};
bool
enable
=
true
;
bool
enable
=
true
;
std
::
string
name
()
const
{
return
"schedule"
;
}
std
::
string
name
()
const
{
return
"schedule"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/simplify_algebra.hpp
View file @
417d6644
...
@@ -15,7 +15,7 @@ struct module;
...
@@ -15,7 +15,7 @@ struct module;
struct
simplify_algebra
struct
simplify_algebra
{
{
std
::
string
name
()
const
{
return
"simplify_algebra"
;
}
std
::
string
name
()
const
{
return
"simplify_algebra"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/simplify_reshapes.hpp
View file @
417d6644
...
@@ -16,7 +16,7 @@ struct module;
...
@@ -16,7 +16,7 @@ struct module;
struct
simplify_reshapes
struct
simplify_reshapes
{
{
std
::
string
name
()
const
{
return
"simplify_reshapes"
;
}
std
::
string
name
()
const
{
return
"simplify_reshapes"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/make_op.cpp
100755 → 100644
View file @
417d6644
...
@@ -5,20 +5,41 @@ namespace migraphx {
...
@@ -5,20 +5,41 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
operation
make_op
(
const
std
::
string
&
name
)
{
return
load_op
(
name
);
}
operation
make_op
(
const
std
::
string
&
name
)
{
return
load_op
(
name
);
}
operation
make_op
(
const
std
::
string
&
name
,
const
value
&
v
)
template
<
class
F
>
operation
make_op_generic
(
const
std
::
string
&
name
,
F
for_each
)
{
{
if
(
not
(
v
.
is_object
()
or
(
v
.
empty
()
and
v
.
is_array
())))
MIGRAPHX_THROW
(
"Value is not an object"
);
auto
op
=
load_op
(
name
);
auto
op
=
load_op
(
name
);
// Merge values
// Merge values
value
w
=
op
.
to_value
();
value
w
=
op
.
to_value
();
for
(
auto
&&
x
:
v
)
for_each
([
&
](
const
auto
&
key
,
const
auto
&
x
)
{
{
if
(
not
w
.
contains
(
key
))
w
.
at
(
x
.
get_key
())
=
x
.
without_key
();
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
}
MIGRAPHX_THROW
(
"No key '"
+
key
+
"' in "
+
name
);
w
.
at
(
key
)
=
x
;
});
op
.
from_value
(
w
);
op
.
from_value
(
w
);
return
op
;
return
op
;
}
}
operation
make_op
(
const
std
::
string
&
name
,
const
std
::
initializer_list
<
std
::
pair
<
std
::
string
,
value
>>&
v
)
{
return
make_op_generic
(
name
,
[
&
](
auto
f
)
{
for
(
auto
&&
[
key
,
x
]
:
v
)
f
(
key
,
x
);
});
}
operation
make_op_from_value
(
const
std
::
string
&
name
,
const
value
&
v
)
{
if
(
not
(
v
.
is_object
()
or
(
v
.
empty
()
and
v
.
is_array
())))
MIGRAPHX_THROW
(
"Value is not an object for make_op: "
+
name
);
return
make_op_generic
(
name
,
[
&
](
auto
f
)
{
for
(
auto
&&
x
:
v
)
f
(
x
.
get_key
(),
x
.
without_key
());
});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/opt/memory_coloring.cpp
View file @
417d6644
...
@@ -4,11 +4,11 @@
...
@@ -4,11 +4,11 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
memory_coloring
::
apply
(
module
&
p
)
const
void
memory_coloring
::
apply
(
module
&
m
)
const
{
{
if
(
!
enabled
(
MIGRAPHX_DISABLE_MEMORY_COLORING
{}))
if
(
!
enabled
(
MIGRAPHX_DISABLE_MEMORY_COLORING
{}))
{
{
memory_coloring_impl
opt
(
&
p
,
allocation_op
,
verify
);
memory_coloring_impl
opt
(
&
m
,
allocation_op
,
verify
);
opt
.
run
();
opt
.
run
();
}
}
}
}
...
...
src/propagate_constant.cpp
View file @
417d6644
...
@@ -20,9 +20,9 @@ bool skip_propogate(instruction_ref ins)
...
@@ -20,9 +20,9 @@ bool skip_propogate(instruction_ref ins)
return
false
;
return
false
;
}
}
void
propagate_constant
::
apply
(
module
&
p
)
const
void
propagate_constant
::
apply
(
module
&
m
)
const
{
{
for
(
auto
i
:
iterator_for
(
p
))
for
(
auto
i
:
iterator_for
(
m
))
{
{
if
(
i
->
name
()
!=
"@literal"
)
if
(
i
->
name
()
!=
"@literal"
)
continue
;
continue
;
...
@@ -42,8 +42,8 @@ void propagate_constant::apply(module& p) const
...
@@ -42,8 +42,8 @@ void propagate_constant::apply(module& p) const
if
(
not
r
.
empty
())
if
(
not
r
.
empty
())
{
{
assert
(
r
.
get_shape
()
==
child
->
get_shape
());
assert
(
r
.
get_shape
()
==
child
->
get_shape
());
auto
l
=
p
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
auto
l
=
m
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
self
(
p
.
replace_instruction
(
child
,
l
));
self
(
m
.
replace_instruction
(
child
,
l
));
}
}
}
}
})(
i
);
})(
i
);
...
...
src/py/migraphx_py.cpp
View file @
417d6644
...
@@ -273,6 +273,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -273,6 +273,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py
::
arg
(
"op"
),
py
::
arg
(
"op"
),
py
::
arg
(
"args"
),
py
::
arg
(
"args"
),
py
::
arg
(
"mod_args"
)
=
std
::
vector
<
migraphx
::
module
*>
{})
py
::
arg
(
"mod_args"
)
=
std
::
vector
<
migraphx
::
module
*>
{})
.
def
(
"add_literal"
,
[](
migraphx
::
module
&
mm
,
py
::
buffer
data
)
{
py
::
buffer_info
info
=
data
.
request
();
auto
literal_shape
=
to_shape
(
info
);
return
mm
.
add_literal
(
literal_shape
,
reinterpret_cast
<
char
*>
(
info
.
ptr
));
},
py
::
arg
(
"data"
))
.
def
(
.
def
(
"add_parameter"
,
"add_parameter"
,
[](
migraphx
::
module
&
mm
,
const
std
::
string
&
name
,
const
migraphx
::
shape
shape
)
{
[](
migraphx
::
module
&
mm
,
const
std
::
string
&
name
,
const
migraphx
::
shape
shape
)
{
...
...
src/rewrite_batchnorm.cpp
View file @
417d6644
...
@@ -14,9 +14,9 @@
...
@@ -14,9 +14,9 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
rewrite_batchnorm
::
apply
(
module
&
p
)
const
void
rewrite_batchnorm
::
apply
(
module
&
m
)
const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
continue
;
continue
;
...
@@ -46,13 +46,13 @@ void rewrite_batchnorm::apply(module& p) const
...
@@ -46,13 +46,13 @@ void rewrite_batchnorm::apply(module& p) const
});
});
auto
broadcast
=
op
::
broadcast
{
1
,
ins
->
get_shape
().
lens
()};
auto
broadcast
=
op
::
broadcast
{
1
,
ins
->
get_shape
().
lens
()};
auto
a_ins
=
p
.
add_literal
({
a
.
get_shape
(),
a
.
data
()});
auto
a_ins
=
m
.
add_literal
({
a
.
get_shape
(),
a
.
data
()});
auto
a_broadcast
=
p
.
insert_instruction
(
ins
,
broadcast
,
a_ins
);
auto
a_broadcast
=
m
.
insert_instruction
(
ins
,
broadcast
,
a_ins
);
auto
mul
=
p
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
ins
->
inputs
().
front
(),
a_broadcast
);
auto
mul
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
ins
->
inputs
().
front
(),
a_broadcast
);
auto
b_ins
=
p
.
add_literal
({
b
.
get_shape
(),
b
.
data
()});
auto
b_ins
=
m
.
add_literal
({
b
.
get_shape
(),
b
.
data
()});
auto
b_broadcast
=
p
.
insert_instruction
(
ins
,
broadcast
,
b_ins
);
auto
b_broadcast
=
m
.
insert_instruction
(
ins
,
broadcast
,
b_ins
);
auto
add
=
p
.
insert_instruction
(
ins
,
make_op
(
"add"
),
mul
,
b_broadcast
);
auto
add
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
mul
,
b_broadcast
);
p
.
replace_instruction
(
ins
,
add
);
m
.
replace_instruction
(
ins
,
add
);
}
}
}
}
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment