Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2f268bc2
Commit
2f268bc2
authored
Jun 12, 2022
by
Paul
Browse files
Merge branch 'develop' into mlir-c
parents
f75c5a38
aa7ff911
Changes
205
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
306 additions
and
74 deletions
+306
-74
src/include/migraphx/rewrite_pooling.hpp
src/include/migraphx/rewrite_pooling.hpp
+1
-1
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+12
-13
src/include/migraphx/schedule.hpp
src/include/migraphx/schedule.hpp
+1
-1
src/include/migraphx/serialize.hpp
src/include/migraphx/serialize.hpp
+0
-1
src/include/migraphx/simplify_algebra.hpp
src/include/migraphx/simplify_algebra.hpp
+1
-1
src/include/migraphx/simplify_reshapes.hpp
src/include/migraphx/simplify_reshapes.hpp
+1
-1
src/include/migraphx/tensor_view.hpp
src/include/migraphx/tensor_view.hpp
+0
-2
src/include/migraphx/verify.hpp
src/include/migraphx/verify.hpp
+0
-1
src/make_op.cpp
src/make_op.cpp
+28
-7
src/module.cpp
src/module.cpp
+8
-1
src/onnx/parse_generic_op.cpp
src/onnx/parse_generic_op.cpp
+3
-2
src/onnx/parse_mean.cpp
src/onnx/parse_mean.cpp
+26
-7
src/onnx/parse_pooling.cpp
src/onnx/parse_pooling.cpp
+17
-7
src/onnx/parse_reversesequence.cpp
src/onnx/parse_reversesequence.cpp
+125
-0
src/onnx/parse_scatter.cpp
src/onnx/parse_scatter.cpp
+44
-0
src/onnx/parse_squeeze.cpp
src/onnx/parse_squeeze.cpp
+1
-1
src/op_enums.cpp
src/op_enums.cpp
+1
-1
src/opt/memory_coloring.cpp
src/opt/memory_coloring.cpp
+2
-2
src/process.cpp
src/process.cpp
+1
-1
src/propagate_constant.cpp
src/propagate_constant.cpp
+34
-24
No files found.
src/include/migraphx/rewrite_pooling.hpp
View file @
2f268bc2
...
...
@@ -15,7 +15,7 @@ struct module;
struct
rewrite_pooling
{
std
::
string
name
()
const
{
return
"rewrite_pooling"
;
}
void
apply
(
module
&
prog
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/rewrite_rnn.hpp
View file @
2f268bc2
...
...
@@ -19,22 +19,22 @@ struct module;
struct
rewrite_rnn
{
std
::
string
name
()
const
{
return
"rewrite_rnn"
;
}
void
apply
(
module
&
prog
)
const
;
void
apply
(
module
&
m
)
const
;
private:
// for vanilla rnn operators
void
apply_vanilla_rnn
(
module
&
prog
,
instruction_ref
ins
)
const
;
void
apply_vanilla_rnn
(
module
&
m
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
vanilla_rnn_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
operation
&
actv_func
)
const
;
std
::
vector
<
operation
>
vanilla_rnn_actv_funcs
(
instruction_ref
ins
)
const
;
// for gru operators
void
apply_gru
(
module
&
prog
,
instruction_ref
ins
)
const
;
void
apply_gru
(
module
&
m
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
gru_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
int
linear_before_reset
,
...
...
@@ -44,9 +44,9 @@ struct rewrite_rnn
std
::
vector
<
operation
>
gru_actv_funcs
(
instruction_ref
ins
)
const
;
// for lstm operators
void
apply_lstm
(
module
&
prog
,
instruction_ref
ins
)
const
;
void
apply_lstm
(
module
&
m
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
lstm_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
const
operation
&
actv_func1
,
...
...
@@ -55,24 +55,23 @@ struct rewrite_rnn
std
::
vector
<
operation
>
lstm_actv_funcs
(
instruction_ref
ins
)
const
;
bool
is_variable_seq_lens
(
const
module
&
prog
,
instruction_ref
seq_lens
)
const
;
instruction_ref
replace_last_hs_output
(
module
&
prog
,
bool
is_variable_seq_lens
(
const
module
&
m
,
instruction_ref
seq_lens
)
const
;
instruction_ref
replace_last_hs_output
(
module
&
m
,
instruction_ref
ins
,
instruction_ref
seq_lens
,
instruction_ref
last_hs_output
,
op
::
rnn_direction
dirct
)
const
;
void
replace_last_cell_output
(
module
&
prog
,
void
replace_last_cell_output
(
module
&
m
,
instruction_ref
ins
,
instruction_ref
seq_lens
,
instruction_ref
cell_outputs
,
instruction_ref
last_cell_output
,
op
::
rnn_direction
dirct
)
const
;
std
::
size_t
get_seq_len
(
const
module
&
prog
,
instruction_ref
input
,
instruction_ref
seq_lens
)
const
;
std
::
size_t
get_seq_len
(
const
module
&
m
,
instruction_ref
input
,
instruction_ref
seq_lens
)
const
;
instruction_ref
pad_hidden_states
(
module
&
prog
,
instruction_ref
pad_hidden_states
(
module
&
m
,
instruction_ref
seq
,
instruction_ref
seq_lens
,
instruction_ref
hs
)
const
;
...
...
src/include/migraphx/schedule.hpp
View file @
2f268bc2
...
...
@@ -19,7 +19,7 @@ struct schedule
schedule_model
model
{};
bool
enable
=
true
;
std
::
string
name
()
const
{
return
"schedule"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/serialize.hpp
View file @
2f268bc2
...
...
@@ -50,7 +50,6 @@ auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{})
value
result
=
value
::
array
{};
for
(
auto
&&
y
:
x
)
{
auto
e
=
to_value
(
y
);
result
.
insert
(
to_value
(
y
));
}
return
result
;
...
...
src/include/migraphx/simplify_algebra.hpp
View file @
2f268bc2
...
...
@@ -15,7 +15,7 @@ struct module;
struct
simplify_algebra
{
std
::
string
name
()
const
{
return
"simplify_algebra"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/simplify_reshapes.hpp
View file @
2f268bc2
...
...
@@ -16,7 +16,7 @@ struct module;
struct
simplify_reshapes
{
std
::
string
name
()
const
{
return
"simplify_reshapes"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/tensor_view.hpp
View file @
2f268bc2
...
...
@@ -120,10 +120,8 @@ struct tensor_view
return
m_data
[
m_shape
.
index
(
this
->
size
()
-
1
)];
}
// cppcheck-suppress functionConst
iterator
begin
()
{
return
{
0
,
{
this
}};
}
// cppcheck-suppress functionConst
iterator
end
()
{
return
{
this
->
size
(),
{
this
}};
}
const_iterator
begin
()
const
{
return
{
0
,
{
this
}};
}
...
...
src/include/migraphx/verify.hpp
View file @
2f268bc2
...
...
@@ -168,7 +168,6 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out
{
double
threshold
=
std
::
numeric_limits
<
range_value
<
R1
>>::
epsilon
()
*
tolerance
;
auto
error
=
rms_range
(
r1
,
r2
);
// cppcheck-suppress uninitvar
if
(
out_error
!=
nullptr
)
*
out_error
=
error
;
return
error
<=
threshold
;
...
...
src/make_op.cpp
100755 → 100644
View file @
2f268bc2
...
...
@@ -5,20 +5,41 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
operation
make_op
(
const
std
::
string
&
name
)
{
return
load_op
(
name
);
}
operation
make_op
(
const
std
::
string
&
name
,
const
value
&
v
)
template
<
class
F
>
operation
make_op_generic
(
const
std
::
string
&
name
,
F
for_each
)
{
if
(
not
(
v
.
is_object
()
or
(
v
.
empty
()
and
v
.
is_array
())))
MIGRAPHX_THROW
(
"Value is not an object"
);
auto
op
=
load_op
(
name
);
// Merge values
value
w
=
op
.
to_value
();
for
(
auto
&&
x
:
v
)
{
w
.
at
(
x
.
get_key
())
=
x
.
without_key
();
}
for_each
([
&
](
const
auto
&
key
,
const
auto
&
x
)
{
if
(
not
w
.
contains
(
key
))
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
MIGRAPHX_THROW
(
"No key '"
+
key
+
"' in "
+
name
);
w
.
at
(
key
)
=
x
;
});
op
.
from_value
(
w
);
return
op
;
}
operation
make_op
(
const
std
::
string
&
name
,
const
std
::
initializer_list
<
std
::
pair
<
std
::
string
,
value
>>&
v
)
{
return
make_op_generic
(
name
,
[
&
](
auto
f
)
{
for
(
auto
&&
[
key
,
x
]
:
v
)
f
(
key
,
x
);
});
}
operation
make_op_from_value
(
const
std
::
string
&
name
,
const
value
&
v
)
{
if
(
not
(
v
.
is_object
()
or
(
v
.
empty
()
and
v
.
is_array
())))
MIGRAPHX_THROW
(
"Value is not an object for make_op: "
+
name
);
return
make_op_generic
(
name
,
[
&
](
auto
f
)
{
for
(
auto
&&
x
:
v
)
f
(
x
.
get_key
(),
x
.
without_key
());
});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/module.cpp
View file @
2f268bc2
...
...
@@ -22,6 +22,8 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_FINALIZE
)
struct
module_impl
{
// A list is used to keep references to an instruction stable
...
...
@@ -555,8 +557,14 @@ instruction_ref module::find_dangling_reference() const
void
module
::
finalize
(
context
&
ctx
)
{
const
bool
trace
=
enabled
(
MIGRAPHX_TRACE_FINALIZE
{});
for
(
auto
ins
:
iterator_for
(
*
this
))
{
if
(
trace
)
{
std
::
cout
<<
"Finalize: "
;
this
->
debug_print
(
ins
);
}
ins
->
finalize
(
ctx
);
for
(
const
auto
&
smod
:
ins
->
module_inputs
())
{
...
...
@@ -731,7 +739,6 @@ std::unordered_map<instruction_ref, std::string>
module
::
print_cpp
(
std
::
ostream
&
os
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
)
const
{
os
<<
"migraphx::module p;"
<<
std
::
endl
;
// cppcheck-suppress variableScope
unsigned
long
seed
=
0
;
names
=
this
->
print
(
[
&
](
auto
ins
,
auto
ins_names
)
{
...
...
src/onnx/parse_generic_op.cpp
View file @
2f268bc2
...
...
@@ -10,6 +10,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{
std
::
vector
<
op_desc
>
operators
()
const
{
// clang-format off
return
{{
"Abs"
,
"abs"
},
{
"Acos"
,
"acos"
},
{
"Acosh"
,
"acosh"
},
...
...
@@ -27,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{
"Flatten"
,
"flatten"
},
{
"Floor"
,
"floor"
},
{
"Gather"
,
"gather"
},
{
"GatherND"
,
"gathernd"
},
{
"Identity"
,
"identity"
},
{
"IsNaN"
,
"isnan"
},
{
"LeakyRelu"
,
"leaky_relu"
},
...
...
@@ -37,8 +39,6 @@ struct parse_generic_op : op_parser<parse_generic_op>
{
"Reciprocal"
,
"recip"
},
{
"Relu"
,
"relu"
},
{
"Round"
,
"round"
},
{
"Scatter"
,
"scatter"
},
{
"ScatterElements"
,
"scatter"
},
{
"Sigmoid"
,
"sigmoid"
},
{
"Sign"
,
"sign"
},
{
"Sin"
,
"sin"
},
...
...
@@ -47,6 +47,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{
"Tan"
,
"tan"
},
{
"Tanh"
,
"tanh"
},
{
"Not"
,
"not"
}};
// clang-format on
}
bool
needs_contiguous
(
const
std
::
string
&
op_name
)
const
...
...
src/onnx/parse_mean.cpp
View file @
2f268bc2
...
...
@@ -2,6 +2,7 @@
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -9,6 +10,9 @@ namespace onnx {
struct
parse_mean
:
op_parser
<
parse_mean
>
{
const
std
::
set
<
shape
::
type_t
>
float_types
=
{
shape
::
float_type
,
shape
::
half_type
,
shape
::
double_type
};
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Mean"
}};
}
/// Calculates the element-wise mean of n>=1 input tensors
...
...
@@ -24,14 +28,29 @@ struct parse_mean : op_parser<parse_mean>
auto
divisor
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
args
[
0
]
->
get_shape
().
type
()},
{
num_data
}});
return
std
::
accumulate
(
args
.
begin
(),
args
.
end
(),
args
[
0
],
[
&
](
auto
&
mean
,
auto
&
data_i
)
{
// Pre-divide each tensor element-wise by n to reduce risk of overflow during summation
data_i
=
info
.
add_broadcastable_binary_op
(
"div"
,
data_i
,
divisor
);
if
(
contains
(
float_types
,
args
[
0
]
->
get_shape
().
type
()))
{
return
std
::
accumulate
(
args
.
begin
()
+
1
,
args
.
end
(),
info
.
add_broadcastable_binary_op
(
"div"
,
args
[
0
],
divisor
),
[
&
](
auto
mean
,
auto
data_i
)
{
// Pre-divide each tensor element-wise by n to reduce risk of
// overflow during summation
auto
div
=
info
.
add_broadcastable_binary_op
(
"div"
,
data_i
,
divisor
);
return
info
.
add_broadcastable_binary_op
(
"add"
,
mean
,
div
);
});
}
else
{
// Compute sum before division for integral types
auto
sum
=
std
::
accumulate
(
args
.
begin
()
+
1
,
args
.
end
(),
args
[
0
],
[
&
](
auto
accum
,
auto
data_i
)
{
return
info
.
add_broadcastable_binary_op
(
"add"
,
accum
,
data_i
);
});
if
(
data_i
!=
args
[
0
])
return
info
.
add_broadcastable_binary_op
(
"add"
,
mean
,
data_i
);
return
data_i
;
});
return
info
.
add_broadcastable_binary_op
(
"div"
,
sum
,
divisor
);
}
}
};
...
...
src/onnx/parse_pooling.cpp
View file @
2f268bc2
...
...
@@ -19,7 +19,9 @@ struct parse_pooling : op_parser<parse_pooling>
return
{{
"AveragePool"
,
"average"
},
{
"GlobalAveragePool"
,
"average"
},
{
"GlobalMaxPool"
,
"max"
},
{
"MaxPool"
,
"max"
}};
{
"MaxPool"
,
"max"
},
{
"LpPool"
,
"lpnorm"
},
{
"GlobalLpPool"
,
"lpnorm"
}};
}
instruction_ref
parse
(
const
op_desc
&
opd
,
...
...
@@ -27,14 +29,16 @@ struct parse_pooling : op_parser<parse_pooling>
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
const
std
::
unordered_map
<
std
::
string
,
op
::
pooling_mode
>
mode_map
=
{
{
"max"
,
op
::
pooling_mode
::
max
},
{
"average"
,
op
::
pooling_mode
::
average
},
{
"lpnorm"
,
op
::
pooling_mode
::
lpnorm
}};
std
::
string
mode
=
opd
.
op_name
;
if
(
mode
!=
"max"
&&
mode
!=
"average"
)
if
(
not
contains
(
mode_map
,
mode
)
)
{
MIGRAPHX_THROW
(
"onnx pooling mode must be
\"
max
\"
or
\"
average
\"
"
);
MIGRAPHX_THROW
(
"onnx pooling mode must be
[
\"
max
\"
,
\"
average
\"
,
\"
lpnorm
\"
]
"
);
}
operation
op
=
make_op
(
"pooling"
,
{{
"mode"
,
mode
==
"average"
?
op
::
pooling_mode
::
average
:
op
::
pooling_mode
::
max
}});
operation
op
=
make_op
(
"pooling"
,
{{
"mode"
,
mode_map
.
at
(
mode
)}});
value
values
=
op
.
to_value
();
auto
l0
=
args
[
0
];
auto
in_lens
=
l0
->
get_shape
().
lens
();
...
...
@@ -74,6 +78,12 @@ struct parse_pooling : op_parser<parse_pooling>
kdims
,
values
[
"lengths"
].
size
(),
"PARSE_POOLING: inconsistent lengths"
);
}
// lp_order attribute
if
(
contains
(
info
.
attributes
,
"p"
))
{
values
[
"lp_order"
]
=
info
.
attributes
.
at
(
"p"
).
i
();
}
// ensure pads availabe only when auto_pad is "NOT_SET"
check_padding_mode
(
info
,
"POOLING"
);
...
...
@@ -118,7 +128,7 @@ struct parse_pooling : op_parser<parse_pooling>
std
::
fill_n
(
values
[
"stride"
].
begin
(),
kdims
,
1
);
}
// used to calculate the supposed output shape
std
::
vector
<
int64_t
>
orig_padding
(
paddings
.
begin
(),
paddings
.
end
())
;
std
::
vector
<
int64_t
>
orig_padding
=
paddings
;
std
::
vector
<
int64_t
>
slice_start
;
std
::
vector
<
int64_t
>
slice_end
;
...
...
src/onnx/parse_reversesequence.cpp
0 → 100644
View file @
2f268bc2
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
//! Parser for ReverseSequence ONNX operator.
/*!
Reverses the data along the time axis for the batches along the batch axis.
The sequence lengths can be given to reverse up to the given length for each batch, keeping the
rest of the sequence in the original order. Variable sequence_lens is not supported in this
version of MIGraphX. You can pass the sequence_lens either as a constant node or an attribute. The
batch axis and time axis must be [0, 1] and not the same.
*/
struct
parse_reversesequence
:
op_parser
<
parse_reversesequence
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"ReverseSequence"
}};
}
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
parser
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
int
batch_axis
=
1
;
if
(
contains
(
info
.
attributes
,
"batch_axis"
))
{
batch_axis
=
info
.
attributes
.
at
(
"batch_axis"
).
i
();
}
if
(
batch_axis
!=
0
and
batch_axis
!=
1
)
{
MIGRAPHX_THROW
(
"REVERSESEQUENCE: batch axis not 0 or 1"
);
}
int
time_axis
=
0
;
if
(
contains
(
info
.
attributes
,
"time_axis"
))
{
time_axis
=
info
.
attributes
.
at
(
"time_axis"
).
i
();
}
if
(
time_axis
!=
0
and
time_axis
!=
1
)
{
MIGRAPHX_THROW
(
"REVERSESEQUENCE: time axis not 0 or 1"
);
}
if
(
time_axis
==
batch_axis
)
{
MIGRAPHX_THROW
(
"REVERSESEQUENCE: time axis and batch axis are the same"
);
}
auto
input
=
args
[
0
];
auto
input_lens
=
input
->
get_shape
().
lens
();
if
(
input_lens
.
size
()
<
2
)
{
MIGRAPHX_THROW
(
"REVERSESEQUENCE: input tensor must have rank >= 2"
);
}
std
::
vector
<
int64_t
>
sequence_lens
;
if
(
args
.
size
()
==
2
)
{
migraphx
::
argument
seq_lens_arg
=
args
.
back
()
->
eval
();
check_arg_empty
(
seq_lens_arg
,
"REVERSESEQUENCE: cannot handle variable sequence_lens"
);
seq_lens_arg
.
visit
([
&
](
auto
s
)
{
sequence_lens
.
assign
(
s
.
begin
(),
s
.
end
());
});
}
else
if
(
contains
(
info
.
attributes
,
"sequence_lens"
))
{
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"sequence_lens"
));
s
.
visit
([
&
](
auto
v
)
{
sequence_lens
.
assign
(
v
.
begin
(),
v
.
end
());
});
}
auto
batch_size
=
input_lens
[
batch_axis
];
auto
time_size
=
input_lens
[
time_axis
];
// this condition may still work if sequence_len's shape was incorrect
if
(
sequence_lens
.
size
()
!=
batch_size
)
{
MIGRAPHX_THROW
(
"REVERSESEQUENCE: sequence_lens has incorrect shape"
);
}
instruction_ref
ret
;
auto
add_slice
=
[
&
info
,
&
input
,
batch_axis
,
time_axis
](
int
b
,
int
t_start
,
int
t_end
)
{
return
info
.
add_instruction
(
make_op
(
"slice"
,
{{
"axes"
,
{
batch_axis
,
time_axis
}},
{
"starts"
,
{
b
,
t_start
}},
{
"ends"
,
{
b
+
1
,
t_end
}}}),
input
);
};
for
(
int
b
=
0
;
b
<
batch_size
;
++
b
)
{
instruction_ref
s0
;
if
(
sequence_lens
[
b
]
>
1
)
{
s0
=
add_slice
(
b
,
0
,
sequence_lens
[
b
]);
s0
=
info
.
add_instruction
(
make_op
(
"reverse"
,
{{
"axes"
,
{
time_axis
}}}),
s0
);
// if reversed less than whole batch, concat rest of batch
if
(
sequence_lens
[
b
]
<
time_size
)
{
auto
s1
=
add_slice
(
b
,
sequence_lens
[
b
],
time_size
);
s0
=
info
.
add_instruction
(
make_op
(
"concat"
,
{{
"axis"
,
time_axis
}}),
s0
,
s1
);
}
}
else
{
// cases where nothing changes
s0
=
add_slice
(
b
,
0
,
time_size
);
}
if
(
b
==
0
)
{
ret
=
s0
;
}
else
{
ret
=
info
.
add_instruction
(
make_op
(
"concat"
,
{{
"axis"
,
batch_axis
}}),
ret
,
s0
);
}
}
return
ret
;
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/parse_scatter.cpp
0 → 100644
View file @
2f268bc2
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
struct
parse_scatter
:
op_parser
<
parse_scatter
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"ScatterElements"
},
{
"Scatter"
}};
}
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
/*parser*/
,
const
onnx_parser
::
node_info
&
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
operation
op
;
std
::
string
op_name
=
"scatter_none"
;
int
axis
=
0
;
if
(
contains
(
info
.
attributes
,
"axis"
))
axis
=
info
.
attributes
.
at
(
"axis"
).
i
();
if
(
contains
(
info
.
attributes
,
"reduction"
))
{
std
::
string
reduction_att
(
info
.
attributes
.
at
(
"reduction"
).
s
());
// check for a valid reduction attribute. We have an operator for each one.
if
(
not
contains
({
"none"
,
"add"
,
"mul"
},
reduction_att
))
MIGRAPHX_THROW
(
"PARSE_SCATTER: unsupported reduction mode "
+
reduction_att
);
// merge scatter with reduction attribute to specify which scatter operation. Future
// reduction op names should follow this pattern and should also be added to the check
// above.
op_name
=
std
::
string
(
"scatter_"
)
+
reduction_att
;
}
op
=
migraphx
::
make_op
(
op_name
,
{{
"axis"
,
axis
}});
return
info
.
add_instruction
(
op
,
args
);
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/parse_squeeze.cpp
View file @
2f268bc2
...
...
@@ -30,11 +30,11 @@ struct parse_squeeze : op_parser<parse_squeeze>
std
::
vector
<
instruction_ref
>
args
)
const
{
auto
op
=
parser
.
load
(
opd
.
op_name
,
info
);
std
::
vector
<
int64_t
>
axes
;
if
(
args
.
size
()
==
2
)
{
auto
arg_axes
=
args
.
at
(
1
)
->
eval
();
check_arg_empty
(
arg_axes
,
"PARSE_"
+
opd
.
op_name
+
": cannot handle variable axes!"
);
std
::
vector
<
int64_t
>
axes
;
arg_axes
.
visit
([
&
](
auto
s
)
{
axes
.
assign
(
s
.
begin
(),
s
.
end
());
});
op
=
assign_axes
(
op
,
axes
);
}
...
...
src/op_enums.cpp
View file @
2f268bc2
...
...
@@ -15,7 +15,7 @@ std::ostream& operator<<(std::ostream& os, pooling_mode v)
{
// the strings for the enum are the same as the values used for onnx parsing
// but this enum is not onnx-specific: strings must be converted when parsing tf
static
const
std
::
vector
<
std
::
string
>
pooling_mode_str
=
{
"average"
,
"max"
};
static
const
std
::
vector
<
std
::
string
>
pooling_mode_str
=
{
"average"
,
"max"
,
"lpnorm"
};
os
<<
pooling_mode_str
[
static_cast
<
std
::
underlying_type
<
pooling_mode
>::
type
>
(
v
)];
return
os
;
}
...
...
src/opt/memory_coloring.cpp
View file @
2f268bc2
...
...
@@ -4,11 +4,11 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
memory_coloring
::
apply
(
module
&
p
)
const
void
memory_coloring
::
apply
(
module
&
m
)
const
{
if
(
!
enabled
(
MIGRAPHX_DISABLE_MEMORY_COLORING
{}))
{
memory_coloring_impl
opt
(
&
p
,
allocation_op
,
verify
);
memory_coloring_impl
opt
(
&
m
,
allocation_op
,
verify
);
opt
.
run
();
}
}
...
...
src/process.cpp
View file @
2f268bc2
...
...
@@ -20,7 +20,6 @@ int exec(const std::string& cmd, const std::function<void(const char*)>& std_out
int
ec
=
0
;
if
(
enabled
(
MIGRAPHX_TRACE_CMD_EXECUTE
{}))
std
::
cout
<<
cmd
<<
std
::
endl
;
std
::
array
<
char
,
128
>
buffer
;
auto
closer
=
[
&
](
FILE
*
stream
)
{
auto
status
=
pclose
(
stream
);
ec
=
WIFEXITED
(
status
)
?
0
:
WEXITSTATUS
(
status
);
// NOLINT
...
...
@@ -30,6 +29,7 @@ int exec(const std::string& cmd, const std::function<void(const char*)>& std_out
std
::
unique_ptr
<
FILE
,
decltype
(
closer
)
>
pipe
(
popen
(
cmd
.
c_str
(),
"r"
),
closer
);
// NOLINT
if
(
!
pipe
)
MIGRAPHX_THROW
(
"popen() failed: "
+
cmd
);
std
::
array
<
char
,
128
>
buffer
;
while
(
fgets
(
buffer
.
data
(),
buffer
.
size
(),
pipe
.
get
())
!=
nullptr
)
std_out
(
buffer
.
data
());
}
...
...
src/propagate_constant.cpp
View file @
2f268bc2
...
...
@@ -3,6 +3,7 @@
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <unordered_set>
namespace
migraphx
{
...
...
@@ -20,33 +21,42 @@ bool skip_propogate(instruction_ref ins)
return
false
;
}
void
propagate_constant
::
apply
(
module
&
p
)
const
bool
is_const
(
instruction_ref
ins
)
{
return
ins
->
can_eval
()
and
not
skip_propogate
(
ins
);
}
void
propagate_constant
::
apply
(
module
&
m
)
const
{
for
(
auto
i
:
iterator_for
(
p
))
std
::
unordered_set
<
instruction_ref
>
const_instrs
;
auto
last
=
std
::
prev
(
m
.
end
());
// Find instructions that can be evaluated to a literal
for
(
auto
i
:
iterator_for
(
m
))
{
if
(
i
->
name
()
!=
"@literal"
)
if
(
i
s_const
(
i
)
and
i
!=
last
)
continue
;
if
(
i
->
outputs
().
empty
())
continue
;
fix
([
&
](
auto
self
,
auto
ins
)
{
std
::
unordered_set
<
instruction_ref
>
children
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
());
for
(
auto
child
:
children
)
{
if
(
child
->
name
()
==
"@literal"
or
skip_propogate
(
child
))
{
self
(
child
);
continue
;
}
auto
r
=
child
->
eval
();
if
(
not
r
.
empty
())
{
assert
(
r
.
get_shape
()
==
child
->
get_shape
());
auto
l
=
p
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
self
(
p
.
replace_instruction
(
child
,
l
));
}
}
})(
i
);
std
::
copy_if
(
i
->
inputs
().
begin
(),
i
->
inputs
().
end
(),
std
::
inserter
(
const_instrs
,
const_instrs
.
begin
()),
[
&
](
const
instruction_ref
ins
)
{
return
is_const
(
ins
)
and
ins
->
name
()
!=
"@literal"
;
});
}
// Compute literals in parallel
std
::
vector
<
instruction_ref
>
const_instrs_vec
{
const_instrs
.
begin
(),
const_instrs
.
end
()};
std
::
vector
<
argument
>
literals
(
const_instrs_vec
.
size
());
par_for
(
const_instrs_vec
.
size
(),
1
,
[
&
](
const
auto
i
)
{
literals
[
i
]
=
const_instrs_vec
[
i
]
->
eval
();
});
// Replace instructions in m
for
(
size_t
i
=
0
;
i
<
const_instrs_vec
.
size
();
i
++
)
{
if
(
not
literals
[
i
].
empty
())
{
assert
(
literals
[
i
].
get_shape
()
==
const_instrs_vec
[
i
]
->
get_shape
());
auto
l
=
m
.
add_literal
(
literals
[
i
].
get_shape
(),
literals
[
i
].
data
());
m
.
replace_instruction
(
const_instrs_vec
[
i
],
l
);
}
}
}
...
...
Prev
1
2
3
4
5
6
7
8
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment